Submit Info #60767

Problem Lang User Status Time Memory
Log of Formal Power Series lisp sansaqua AC 586 ms 53.29 MiB

ケース詳細
Name Status Time Memory
example_00 AC 9 ms 1.59 MiB
max_all_zero_00 AC 274 ms 26.22 MiB
max_random_00 AC 586 ms 53.20 MiB
max_random_01 AC 586 ms 53.20 MiB
max_random_02 AC 586 ms 53.29 MiB
max_random_03 AC 585 ms 53.23 MiB
max_random_04 AC 584 ms 53.28 MiB
near_262144_00 AC 294 ms 28.20 MiB
near_262144_01 AC 292 ms 27.16 MiB
near_262144_02 AC 412 ms 33.20 MiB
random_00 AC 541 ms 46.81 MiB
random_01 AC 572 ms 51.05 MiB
random_02 AC 71 ms 7.55 MiB
random_03 AC 557 ms 49.09 MiB
random_04 AC 496 ms 40.12 MiB
small_degree_00 AC 9 ms 1.57 MiB
small_degree_01 AC 9 ms 1.59 MiB
small_degree_02 AC 8 ms 1.57 MiB
small_degree_03 AC 8 ms 1.56 MiB
small_degree_04 AC 8 ms 1.68 MiB
small_degree_05 AC 8 ms 1.57 MiB
small_degree_06 AC 8 ms 1.66 MiB
small_degree_07 AC 9 ms 1.58 MiB
small_degree_08 AC 8 ms 1.61 MiB
small_degree_09 AC 9 ms 1.66 MiB

;;;; Solution for "Log of Formal Power Series" (Library Checker), SBCL only.
;;;;
;;;; Layout of this file:
;;;;   1. CL-USER prelude (optimization policy, reader macro, integer types).
;;;;   2. Library packages pasted in verbatim:
;;;;        cp/experimental/barrett  - FAST-MOD / %HIMOD compiler intrinsics
;;;;        cp/mod-power             - modular exponentiation
;;;;        cp/mod-inverse           - modular inverse (extended Euclid)
;;;;        cp/ntt                   - DEFINE-NTT macro (NTT, inverse NTT, convolution)
;;;;        cp/polynomial-ntt        - polynomial algorithms over +NTT-MOD+
;;;;        cp/read-fixnum           - fast integer reader
;;;;        cp/println-sequence      - space-separated sequence printer
;;;;   3. MAIN, which is invoked at load time unless SWANK is present.

(in-package :cl-user)
(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; Optimization policy spliced into hot functions via (declare #.*opt*).
  ;; Under SWANK safety stays at 2; otherwise safety and debug are dropped to 0.
  (defparameter *opt*
    #+swank '(optimize (speed 3) (safety 2))
    #-swank '(optimize (speed 3) (safety 0) (debug 0)))
  #+swank (ql:quickload '(:cl-debug-print :fiveam :cp/util) :silent t)
  #+swank (use-package :cp/util :cl-user)
  ;; Without SWANK, #>FORM simply reads as (VALUES FORM).
  #-swank (set-dispatch-macro-character
           #\# #\> (lambda (s c p) (declare (ignore c p)) `(values ,(read s nil nil t))))
  #+sbcl (dolist (f '(:popcnt :sse4)) (pushnew f sb-c:*backend-subfeatures*))
  (setq *random-state* (make-random-state t)))
;; While compiling outside SWANK, break on every warning that is not a
;; style-warning.
#-swank (eval-when (:compile-toplevel)
          (setq *break-on-signals* '(and warning (not style-warning))))
#+swank (set-dispatch-macro-character #\# #\> #'cl-debug-print:debug-print-reader)

;; Defines UINTn as (unsigned-byte n) and INTn as (signed-byte n) for each
;; listed width.
(macrolet ((def (b)
             `(progn (deftype ,(intern (format nil "UINT~A" b)) () '(unsigned-byte ,b))
                     (deftype ,(intern (format nil "INT~A" b)) () '(signed-byte ,b))))
           (define-int-types (&rest bits)
             `(progn ,@(mapcar (lambda (b) `(def ,b)) bits))))
  (define-int-types 2 4 7 8 15 16 31 32 62 63 64))

(defconstant +mod+ 1000000007)

;; Debug printer: expands to a FORMAT call under SWANK and to NIL otherwise.
(defmacro dbg (&rest forms)
  (declare (ignorable forms))
  #+swank
  (if (= (length forms) 1)
      `(format *error-output* "~A => ~A~%" ',(car forms) ,(car forms))
      `(format *error-output* "~A => ~A~%" ',forms `(,,@forms))))

(declaim (inline println))
(defun println (obj &optional (stream *standard-output*))
  "Prints OBJ with PRINC followed by a newline to STREAM and returns OBJ.
When OBJ is a DOUBLE-FLOAT, *READ-DEFAULT-FLOAT-FORMAT* is bound to
DOUBLE-FLOAT while printing."
  (let ((*read-default-float-format*
          (if (typep obj 'double-float) 'double-float *read-default-float-format*)))
    (prog1 (princ obj stream) (terpri stream))))

;; BEGIN_INSERTED_CONTENTS
(defpackage :cp/experimental/barrett
  (:use :cl)
  (:import-from :sb-ext #:truly-the)
  (:import-from :sb-c
                #:deftransform #:derive-type #:defoptimizer #:defknown
                #:movable #:foldable #:flushable #:commutative
                #:lvar-type #:lvar-value #:give-up-ir1-transform
                #:integer-type-numeric-bounds #:define-vop)
  (:import-from :sb-vm
                #:move #:inst #:eax-offset #:edx-offset
                #:any-reg #:control-stack #:unsigned-reg #:positive-fixnum
                #:fixnumize #:ea)
  (:import-from :sb-kernel #:specifier-type)
  (:import-from :sb-int #:explicit-check #:constant-arg)
  (:export #:fast-mod #:%himod)
  (:documentation "Provides Barrett reduction."))
(in-package :cp/experimental/barrett)

(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; Result-type deriver for *-HIGH62: when both argument types have integer
  ;; upper bounds, the result is bounded by (high1 * high2) >> 62.
  (defun derive-* (x y)
    (let ((high1 (nth-value 1 (integer-type-numeric-bounds (lvar-type x))))
          (high2 (nth-value 1 (integer-type-numeric-bounds (lvar-type y)))))
      (specifier-type (if (and (integerp high1) (integerp high2))
                          `(integer 0 ,(ash (* high1 high2) -62))
                          `(integer 0)))))

  ;; Result-type deriver for modular operations: [0, high) when the modulus
  ;; type has an integer upper bound HIGH, otherwise any non-negative integer.
  (defun derive-mod (modulus)
    (let ((high (nth-value 1 (integer-type-numeric-bounds (lvar-type modulus)))))
      (specifier-type (if (integerp high)
                          `(integer 0 (,high))
                          `(integer 0)))))

  ;; *-high62
  (defknown *-high62 ((unsigned-byte 62) (unsigned-byte 62)) (unsigned-byte 62)
      (movable foldable flushable commutative sb-c::always-translatable)
    :overwrite-fndb-silently t)

  (defoptimizer (*-high62 derive-type) ((x y))
    (derive-* x y))

  ;; Register/stack operand version: MUL leaves the high word of the product
  ;; in EDX, which is then shifted left by one and returned.
  ;; NOTE(review): presumably this yields floor(x*y / 2^62) as a tagged fixnum,
  ;; assuming a one-bit fixnum tag on both operands -- confirm for the target
  ;; SBCL build.
  (define-vop (fast-*-high62/fixnum)
    (:translate *-high62)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target eax)
           (y :scs (any-reg control-stack)))
    (:arg-types positive-fixnum positive-fixnum)
    (:temporary (:sc any-reg :offset eax-offset :from (:argument 0) :to :result) eax)
    (:temporary (:sc any-reg :offset edx-offset :target r :from :eval :to :result) edx)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline *-high62")
    (:vop-var vop)
    (:save-p :compute-only)
    (:generator 6
      (move eax x)
      (inst mul eax y)
      (inst shl edx 1)
      (move r edx)))

  ;; Same as above, but the second operand is a compile-time constant that is
  ;; emitted as an inline qword constant.
  (define-vop (fast-c-*-high62-/fixnum)
    (:translate *-high62)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target eax))
    (:info y)
    (:arg-types positive-fixnum (:constant (unsigned-byte 62)))
    (:temporary (:sc any-reg :offset eax-offset :from (:argument 0) :to :result) eax)
    (:temporary (:sc any-reg :offset edx-offset :target r :from :eval :to :result) edx)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline constant *-high62")
    (:vop-var vop)
    (:save-p :compute-only)
    (:generator 6
      (move eax x)
      (inst mul eax (sb-c:register-inline-constant :qword y))
      (inst shl edx 1)
      (move r edx)))

  ;; Out-of-line stub; the body is itself translated by the VOPs above.
  (defun *-high62 (x y)
    (declare (explicit-check))
    (*-high62 x y))

  ;; %himod
  (defknown %himod ((unsigned-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
      (movable foldable flushable sb-c:always-translatable)
    :overwrite-fndb-silently t)

  (defoptimizer (%himod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))

  ;; Branch-free conditional subtraction by a constant modulus M:
  ;; R := X; Y := R - M (via LEA); CMP R against (fixnumized M) - 2 and CMOVA
  ;; R := Y.
  ;; NOTE(review): with a one-bit fixnum tag this selects X - M exactly when
  ;; X >= M -- confirm against the target SBCL's fixnum representation.
  (define-vop (fast-c-%himod)
    (:translate %himod)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target r))
    (:info m)
    (:arg-types positive-fixnum (:constant (unsigned-byte 31)))
    (:temporary (:sc any-reg :from :eval :to :result) y)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline constant %himod")
    (:vop-var vop)
    (:generator 4
      ;; GPR-TN-P and EA vs. MAKE-EA are probed at read time because their
      ;; availability differs between SBCL versions.
      (assert #.(if (find-symbol "GPR-TN-P" :sb-vm)
                    `(funcall (intern "GPR-TN-P" :sb-vm) x)
                    t))
      (when (sb-c:tn-p m)
        (assert (sb-c:sc-is m sb-vm::immediate))
        (setq m (sb-c::tn-value m)))
      (setq m (fixnumize m))
      (move r x)
      (inst cmp r (- m 2))
      (inst lea y #.(if (fboundp 'ea)
                        `(sb-vm::ea (- m) r)
                        `(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp (- m) :base r)))
      (inst cmov :a r y)))

  ;; fast-mod
  (defknown fast-mod (integer unsigned-byte) unsigned-byte
      (movable foldable flushable)
    :overwrite-fndb-silently t)

  (defoptimizer (fast-mod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))

  (defun fast-mod (number divisor)
    "Is equivalent to CL:MOD for integer arguments"
    (declare (explicit-check))
    (mod number divisor))

  ;; For a (unsigned-byte 62) NUMBER and a constant (unsigned-byte 31) DIVISOR
  ;; greater than 1, rewrite FAST-MOD as:
  ;;   m = floor(2^62 / mod); q = high62(number * m); x = number - q * mod;
  ;;   result = x if x < mod, else x - mod.
  ;; The single conditional subtraction relies on x < 2 * mod.
  (deftransform fast-mod ((number divisor)
                          ((unsigned-byte 62) (constant-arg (unsigned-byte 31))) *
                          :important t)
    "convert FAST-MOD to barrett reduction"
    (let ((mod (lvar-value divisor)))
      (if (<= mod 1)
          (give-up-ir1-transform)
          (let ((m (floor (ash 1 62) mod)))
            `(let* ((q (*-high62 number ,m))
                    (x (truly-the (signed-byte 32) (- number (* q ,mod)))))
               (if (< x ,mod) x (- x ,mod))
               ;; Not so effective because branch prediction is usually correct?
               ;; (sb-ext:truly-the (mod ,mod) (%himod x ,mod))
               ))))))

(defpackage :cp/mod-power
  (:use :cl)
  (:export #:mod-power))
(in-package :cp/mod-power)

(declaim (inline mod-power))
(defun mod-power (base power modulus)
  "Returns BASE^POWER mod MODULUS. Note: 0^0 = 1.

BASE := integer
POWER, MODULUS := non-negative fixnum"
  (declare ((integer 0 #.most-positive-fixnum) modulus power)
           (integer base))
  ;; Binary exponentiation; RES starts at (mod 1 modulus) so that the result
  ;; is 0 when MODULUS is 1.
  (let ((base (mod base modulus))
        (res (mod 1 modulus)))
    (declare ((integer 0 #.most-positive-fixnum) base res))
    (loop while (> power 0)
          when (oddp power)
          do (setq res (mod (* res base) modulus))
          do (setq base (mod (* base base) modulus)
                   power (ash power -1)))
    res))

(defpackage :cp/mod-inverse
  (:use :cl)
  #+sbcl (:import-from #:sb-c #:defoptimizer #:lvar-type #:integer-type-numeric-bounds
                       #:derive-type #:flushable #:foldable)
  #+sbcl (:import-from :sb-kernel #:specifier-type)
  (:export #:mod-inverse))
(in-package :cp/mod-inverse)

;; Tell the SBCL compiler that both functions return a value in [0, modulus).
#+sbcl
(eval-when (:compile-toplevel :load-toplevel :execute)
  (sb-c:defknown %mod-inverse ((integer 0) (integer 1)) (integer 0)
      (flushable foldable)
    :overwrite-fndb-silently t)
  (sb-c:defknown mod-inverse (integer (integer 1)) (integer 0)
      (flushable foldable)
    :overwrite-fndb-silently t)
  (defun derive-mod (modulus)
    (let ((high (nth-value 1 (integer-type-numeric-bounds (lvar-type modulus)))))
      (specifier-type (if (integerp high)
                          `(integer 0 (,high))
                          `(integer 0)))))
  (defoptimizer (%mod-inverse derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))
  (defoptimizer (mod-inverse derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus)))

;; Extended Euclidean algorithm on (INTEGER, MODULUS) tracking only the
;; coefficient of INTEGER. The loop body is instantiated for three integer
;; widths chosen by the size of MODULUS. No coprimality check is done here;
;; MOD-INVERSE verifies the result.
(defun %mod-inverse (integer modulus)
  (declare (optimize (speed 3) (safety 0))
           #+sbcl (sb-ext:muffle-conditions sb-ext:compiler-note))
  (macrolet ((frob (stype)
               `(let ((a integer)
                      (b modulus)
                      (u 1)
                      (v 0))
                  (declare (,stype a b u v))
                  (loop until (zerop b)
                        for quot = (floor a b)
                        do (decf a (the ,stype (* quot b)))
                           (rotatef a b)
                           (decf u (the ,stype (* quot v)))
                           (rotatef u v))
                  ;; Normalize a negative coefficient into [0, modulus).
                  (if (< u 0)
                      (+ u modulus)
                      u))))
    (typecase modulus
      ((unsigned-byte 31) (frob (signed-byte 32)))
      ((unsigned-byte 62) (frob (signed-byte 63)))
      (otherwise (frob integer)))))

(declaim (inline mod-inverse))
(defun mod-inverse (integer modulus)
  "Solves ax = 1 mod m. Signals DIVISION-BY-ZERO when INTEGER and MODULUS are not
coprime."
  (let* ((integer (mod integer modulus))
         (result (%mod-inverse integer modulus)))
    ;; Verify the candidate; MODULUS = 1 is accepted unconditionally.
    (unless (or (= 1 (mod (* integer result) modulus)) (= 1 modulus))
      (error 'division-by-zero
             :operands (list integer modulus)
             :operation 'mod-inverse))
    result))

(defpackage :cp/ntt
  (:use :cl :cp/mod-inverse :cp/experimental/barrett)
  (:export #:define-ntt #:check-ntt-vector #:ntt-int #:ntt-vector #:+ntt-mod+)
  (:documentation "Provides fast number theoretic transform.

Reference:
https://github.com/ei1333/library/blob/master/math/fft/number-theoretic-transform-friendly-mod-int.cpp
https://github.com/atcoder/ac-library/tree/master/atcoder"))
(in-package :cp/ntt)

(deftype ntt-int () '(unsigned-byte 31))
(deftype ntt-vector () '(simple-array ntt-int (*)))

;; These helpers are needed while DEFINE-NTT is being macroexpanded, hence the
;; EVAL-WHEN.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (declaim (inline %tzcount))
  (defun %tzcount (x)
    "Returns the number of the trailing zero bits. Note that (%TZCOUNT 0) = -1."
    (- (integer-length (logand x (- x))) 1))
  ;; Binary exponentiation restricted to NTT-INT operands.
  (defun %mod-power (base exp modulus)
    (declare (ntt-int base exp modulus))
    (let ((res 1))
      (declare (ntt-int res))
      (loop while (> exp 0)
            when (oddp exp)
            do (setq res (mod (* res base) modulus))
            do (setq base (mod (* base base) modulus)
                     exp (ash exp -1)))
      res))
  ;; Inverse computed as x^(modulus-2); assumes MODULUS is prime.
  (defun %mod-inverse (x modulus)
    (%mod-power x (- modulus 2) modulus))
  (defun %calc-generator (modulus)
    "MODULUS must be prime."
    (declare (ntt-int modulus))
    (assert (>= modulus 2))
    (case modulus
      ;; Hard-coded answers for frequently used moduli.
      (2 1)
      (167772161 3)
      (469762049 3)
      (754974721 11)
      (998244353 3)
      (otherwise
       ;; Collect the distinct prime divisors of MODULUS - 1 in DIVS (2 is
       ;; stored first), then return the smallest G >= 2 such that
       ;; G^((MODULUS-1)/d) /= 1 for every collected divisor d.
       (let ((divs (make-array 20 :element-type 'ntt-int :initial-element 0))
             (end 1)
             (x (floor (- modulus 1) 2)))
         (declare ((integer 0 #.most-positive-fixnum) x))
         (setf (aref divs 0) 2)
         (loop while (evenp x) do (setq x (floor x 2)))
         (loop for i of-type ntt-int from 3 by 2
               while (<= (* i i) x)
               when (zerop (mod x i))
               do (setf (aref divs end) i)
                  (incf end)
                  (loop while (zerop (mod x i))
                        do (setq x (floor x i))))
         (when (> x 1)
           (setf (aref divs end) x)
           (incf end))
         (loop for g of-type ntt-int from 2
               do (dotimes (i end (return-from %calc-generator g))
                    (when (= 1 (%mod-power g (floor (- modulus 1) (aref divs i)) modulus))
                      (return)))))))))

(declaim (ftype (function * (values ntt-vector &optional)) %adjust-array))
(defun %adjust-array (vector length)
  "This function always copies VECTOR. (ANSI CL doesn't state whether
CL:ADJUST-ARRAY should copy the array or not.)"
  (declare (optimize (speed 3))
           (vector vector)
           ((mod #.array-dimension-limit) length))
  (let ((vector (coerce vector 'ntt-vector)))
    (if (= (length vector) length)
        (copy-seq vector)
        (let ((res (make-array length :element-type 'ntt-int :initial-element 0)))
          (replace res vector)
          res))))

(defun check-ntt-vector (vector)
  "Asserts that the length of VECTOR is zero or a power of two and is of type
NTT-INT."
  (declare (optimize (speed 3))
           (vector vector))
  (let ((len (length vector)))
    (assert (and (zerop (logand len (- len 1))) ; power of two
                 (typep len 'ntt-int)))))

(defmacro define-ntt (modulus &key ntt inverse-ntt convolve &environment env)
  "Defines three functions for the constant prime MODULUS:

NTT (default name NTT!) transforms a vector in place;
INVERSE-NTT (default name INVERSE-NTT!) applies the inverse butterflies in
place, and additionally multiplies by 1/length when its optional argument is
true;
CONVOLVE (default name CONVOLVE) returns the product of two vectors modulo
MODULUS as a fresh NTT-VECTOR.

Default names are interned in the package current at macroexpansion time."
  (assert (constantp modulus env))
  (let* ((modulus #+sbcl (sb-int:constant-form-value modulus env)
                  #-sbcl modulus)
         (ntt (or ntt (intern "NTT!")))
         (inverse-ntt (or inverse-ntt (intern "INVERSE-NTT!")))
         (convolve (or convolve (intern "CONVOLVE")))
         (ntt-base (gensym "*NTT-BASE*"))
         (ntt-inv-base (gensym "*NTT-INV-BASE*"))
         ;; Number of factors of two in MODULUS - 1.
         (base-size (%tzcount (- modulus 1)))
         (root (%calc-generator modulus))
         (modulus (sb-int:constant-form-value modulus env)))
    (declare (ntt-int modulus))
    `(progn
       ;; Twiddle tables: entry i holds -ROOT^((MODULUS-1) >> (i+2)) mod
       ;; MODULUS, and the inverse table holds its modular inverse.
       (declaim (ntt-vector ,ntt-base ,ntt-inv-base))
       (sb-ext:define-load-time-global ,ntt-base
         (make-array ,base-size :element-type 'ntt-int))
       (sb-ext:define-load-time-global ,ntt-inv-base
         (make-array ,base-size :element-type 'ntt-int))
       (dotimes (i ,base-size)
         (setf (aref ,ntt-base i)
               (mod (- (%mod-power ,root (ash (- ,modulus 1) (- (+ i 2))) ,modulus))
                    ,modulus)
               (aref ,ntt-inv-base i)
               (%mod-inverse (aref ,ntt-base i) ,modulus)))

       (declaim (ftype (function * (values ntt-vector &optional)) ,ntt))
       (defun ,ntt (vector)
         (declare (optimize (speed 3) (safety 0))
                  (vector vector))
         (check-ntt-vector vector)
         (labels ((mod* (x y) (fast-mod (* x y) ,modulus))
                  (mod+ (x y) (%himod (+ x y) ,modulus))
                  (mod- (x y) (mod+ x (the ntt-int (- ,modulus y)))))
           (declare (inline mod* mod+ mod-))
           (let* ((vector (coerce vector 'ntt-vector))
                  (len (length vector))
                  (base ,ntt-base))
             (declare (ntt-vector vector base)
                      (ntt-int len))
             (when (<= len 1)
               (return-from ,ntt vector))
             ;; M is the half-width of the butterflies; it halves on every
             ;; pass. W and K restart at 1 and 0 for each pass, and W is
             ;; advanced per block through the table indexed by tzcount(K).
             (loop for m of-type ntt-int = (ash len -1) then (ash m -1)
                   while (> m 0)
                   for w of-type ntt-int = 1
                   for k of-type ntt-int = 0
                   do (loop for s of-type ntt-int from 0 below len by (* 2 m)
                            do (loop for i from s below (+ s m)
                                     for j from (+ s m)
                                     for x = (aref vector i)
                                     for y = (mod* (aref vector j) w)
                                     do (setf (aref vector i) (mod+ x y)
                                              (aref vector j) (mod- x y)))
                               (incf k)
                               (setq w (mod* w (aref base (%tzcount k))))))
             vector)))

       (declaim (ftype (function * (values ntt-vector &optional)) ,inverse-ntt))
       (defun ,inverse-ntt (vector &optional inverse)
         (declare (optimize (speed 3) (safety 0))
                  (vector vector))
         (check-ntt-vector vector)
         (labels ((mod* (x y) (fast-mod (* x y) ,modulus))
                  (mod+ (x y) (%himod (+ x y) ,modulus))
                  (mod- (x y) (mod+ x (the ntt-int (- ,modulus y)))))
           (declare (inline mod* mod+ mod-))
           (let* ((vector (coerce vector 'ntt-vector))
                  (len (length vector))
                  (base ,ntt-inv-base))
             (declare (ntt-vector vector base)
                      (ntt-int len))
             (when (<= len 1)
               (return-from ,inverse-ntt vector))
             ;; Mirror image of the forward transform: the half-width M
             ;; doubles on every pass and the inverse twiddle table is used.
             (loop for m of-type ntt-int = 1 then (ash m 1)
                   while (< m len)
                   for w of-type ntt-int = 1
                   for k of-type ntt-int = 0
                   do (loop for s of-type ntt-int from 0 below len by (* 2 m)
                            do (loop for i from s below (+ s m)
                                     for j from (+ s m)
                                     for x = (aref vector i)
                                     for y = (aref vector j)
                                     do (setf (aref vector i) (mod+ x y)
                                              (aref vector j) (mod* (mod- x y) w)))
                               (incf k)
                               (setq w (mod* w (aref base (%tzcount k))))))
             ;; The 1/len scaling is applied only on request.
             (when inverse
               (let ((inv-len (mod-inverse len ,modulus)))
                 (dotimes (i len)
                   (setf (aref vector i) (mod* inv-len (aref vector i))))))
             vector)))

       (declaim (ftype (function * (values ntt-vector &optional)) ,convolve))
       (defun ,convolve (vector1 vector2)
         (declare (optimize (speed 3))
                  (vector vector1 vector2))
         ;; TODO: if (EQ VECTOR1 VECTOR2) holds, the number of FFTs can be
         ;; reduced.
         (let* ((len1 (length vector1))
                (len2 (length vector2))
                (mul-len (max 0 (- (+ len1 len2) 1)))
                (vector1 (coerce vector1 'ntt-vector))
                (vector2 (coerce vector2 'ntt-vector)))
           (declare (ntt-vector vector1 vector2)
                    ((mod #.array-dimension-limit) mul-len))
           ;; An empty operand yields an empty product.
           (when (or (zerop len1) (zerop len2))
             (return-from ,convolve (make-array 0 :element-type 'ntt-int)))
           ;; naive convolution
           (when (<= (min len1 len2) 64)
             (let ((res (make-array mul-len :element-type 'ntt-int :initial-element 0)))
               (declare (optimize (speed 3) (safety 0)))
               (dotimes (d mul-len)
                 ;; 0 <= i <= deg1, 0 <= j <= deg2
                 (loop with coef of-type ntt-int = 0
                       for i from (max 0 (- d (- len2 1))) to (min d (- len1 1))
                       for j = (- d i)
                       do (setq coef (fast-mod (+ coef (* (aref vector1 i) (aref vector2 j)))
                                               ,modulus))
                       finally (setf (aref res d) coef)))
               (return-from ,convolve res)))
           ;; Otherwise: transform zero-padded copies, multiply pointwise,
           ;; transform back with scaling, and cut to the product length.
           (let* (;; power of two ceiling
                  (required-len (ash 1 (integer-length (max 0 (- mul-len 1)))))
                  (vector1 (,ntt (%adjust-array vector1 required-len)))
                  (vector2 (,ntt (%adjust-array vector2 required-len))))
             (dotimes (i required-len)
               (setf (aref vector1 i)
                     (fast-mod (* (aref vector1 i) (aref vector2 i)) ,modulus)))
             (adjust-array (,inverse-ntt vector1 t) mul-len)))))))

(defconstant +ntt-mod+ 998244353)

#+(or)
(define-ntt +ntt-mod+)

(defpackage :cp/polynomial-ntt
  (:use :cl :cp/ntt :cp/mod-inverse :cp/mod-power)
  ;; POLY-DIFFERENTIATE! is the function actually defined below. The
  ;; misspelled POLY-DIFFERENTIATE1 names no function and is kept only so
  ;; that existing package-qualified references still read.
  (:export #:poly-multiply #:poly-inverse #:poly-floor #:poly-mod #:poly-sub #:poly-add
           #:multipoint-eval #:poly-total-prod #:chirp-z #:bostan-mori
           #:poly-differentiate! #:poly-differentiate1 #:poly-integrate #:poly-log))
(in-package :cp/polynomial-ntt)

;; TODO: integrate with cp/polynomial

;; Defines NTT!, INVERSE-NTT! and POLY-MULTIPLY in this package.
(define-ntt +ntt-mod+
  :convolve poly-multiply)

(declaim (inline %adjust))
(defun %adjust (vector size)
  "Returns VECTOR itself when SIZE is NIL or equals its length; otherwise
returns a fresh vector of length SIZE holding a zero-padded or truncated copy."
  (declare (ntt-vector vector))
  (if (or (null size) (= size (length vector)))
      vector
      (let ((res (make-array size :element-type 'ntt-int :initial-element 0)))
        (replace res vector)
        res)))

(declaim (inline %power-of-two-ceiling))
(defun %power-of-two-ceiling (x)
  (declare (ntt-int x))
  (ash 1 (integer-length (- x 1))))

;; (declaim (ftype (function * (values ntt-vector &optional)) poly-inverse))
;; (defun poly-inverse (poly &optional result-length)
;;   (declare (optimize (speed 3))
;;            (vector poly)
;;            ((or null fixnum) result-length))
;;   (let* ((poly (coerce poly 'ntt-vector))
;;          (n (length poly)))
;;     (declare (ntt-vector poly))
;;     (when (or (zerop n)
;;               (zerop (aref poly 0)))
;;       (error 'division-by-zero
;;              :operation #'poly-inverse
;;              :operands (list poly)))
;;     (let ((res (make-array 1
;;                            :element-type 'ntt-int
;;                            :initial-element (mod-inverse (aref poly 0) +ntt-mod+)))
;;           (result-length (or result-length n)))
;;       (declare (ntt-vector res))
;;       (loop for i of-type ntt-int = 1 then (ash i 1)
;;             while (< i result-length)
;;             for decr = (poly-multiply (poly-multiply res res)
;;                                       (subseq poly 0 (min (length poly) (* 2 i))))
;;             for decr-len = (length decr)
;;             do (setq res (adjust-array res (* 2 i) :initial-element 0))
;;                (dotimes (j (* 2 i))
;;                  (setf (aref res j)
;;                        (mod (the ntt-int
;;                                  (+ (mod (* 2 (aref res j)) +ntt-mod+)
;;                                     (if (>= j decr-len) 0 (- +ntt-mod+ (aref decr j)))))
;;                             +ntt-mod+))))
;;       (adjust-array res result-length))))

;; Reference: https://opt-cp.com/fps-fast-algorithms/
(declaim (ftype (function * (values ntt-vector &optional)) poly-inverse))
(defun poly-inverse (poly &optional result-length)
  "Returns the first RESULT-LENGTH coefficients (default: the length of POLY) of
the multiplicative inverse of POLY as a formal power series. Signals
DIVISION-BY-ZERO when POLY is empty or its constant term is zero."
  (declare #.cl-user::*opt*
           (vector poly)
           ((or null fixnum) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (n (length poly)))
    (declare (ntt-vector poly))
    (when (or (zerop n)
              (zerop (aref poly 0)))
      (error 'division-by-zero
             :operation #'poly-inverse
             :operands (list poly)))
    (let* ((result-length (or result-length n))
           (res (make-array 1
                            :element-type 'ntt-int
                            :initial-element (mod-inverse (aref poly 0) +ntt-mod+))))
      (declare (ntt-vector res))
      ;; Each pass doubles the number of known coefficients from I to 2I
      ;; using transforms of length 2I.
      (loop for i of-type ntt-int = 1 then (ash i 1)
            while (< i result-length)
            for f of-type ntt-vector = (make-array (* 2 i) :element-type 'ntt-int :initial-element 0)
            for g of-type ntt-vector = (make-array (* 2 i) :element-type 'ntt-int :initial-element 0)
            do (replace f poly :end2 (min n (* 2 i)))
               (replace g res)
               (ntt! f)
               (ntt! g)
               (dotimes (j (* 2 i))
                 (setf (aref f j) (mod (* (aref g j) (aref f j)) +ntt-mod+)))
               (inverse-ntt! f)
               ;; Keep only the upper half of the product, moved to the
               ;; lower half.
               (replace f f :start1 0 :end1 i :start2 i :end2 (* 2 i))
               (fill f 0 :start i :end (* 2 i))
               (ntt! f)
               (dotimes (j (* 2 i))
                 (setf (aref f j) (mod (* (aref g j) (aref f j)) +ntt-mod+)))
               (inverse-ntt! f)
               ;; Both inverse transforms above were unscaled, so scale once
               ;; by -(1/(2i))^2; the sign turns the update into a
               ;; subtraction.
               (let ((inv-len (mod-inverse (* 2 i) +ntt-mod+)))
                 (setq inv-len (mod (* inv-len (- +ntt-mod+ inv-len)) +ntt-mod+))
                 (dotimes (j i)
                   (setf (aref f j) (mod (* inv-len (aref f j)) +ntt-mod+)))
                 (setq res (%adjust res (* 2 i)))
                 (replace res f :start1 i)))
      (%adjust res result-length))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-floor))
(defun poly-floor (poly1 poly2)
  "Returns the quotient of POLY1 divided by POLY2. Returns an empty vector when
POLY2 has more significant coefficients than POLY1."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((poly1 (coerce poly1 'ntt-vector))
         (poly2 (coerce poly2 'ntt-vector))
         ;; DEG1/DEG2 are the lengths after dropping trailing zeros.
         (deg1 (+ 1 (or (position 0 poly1 :from-end t :test-not #'eql) -1)))
         (deg2 (+ 1 (or (position 0 poly2 :from-end t :test-not #'eql) -1))))
    (when (> deg2 deg1)
      (return-from poly-floor (make-array 0 :element-type 'ntt-int)))
    ;; Work on the reversed polynomials: the reversed quotient is
    ;; rev(poly1) * rev(poly2)^-1 truncated to RES-LEN terms.
    (setq poly1 (nreverse (subseq poly1 0 deg1))
          poly2 (nreverse (subseq poly2 0 deg2)))
    (let* ((res-len (+ 1 (- deg1 deg2)))
           (res (adjust-array (poly-multiply poly1 (poly-inverse poly2 res-len))
                              res-len)))
      (nreverse res))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-sub))
(defun poly-sub (poly1 poly2)
  "Returns POLY1 - POLY2 with trailing zero coefficients removed."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((poly1 (coerce poly1 'ntt-vector))
         (poly2 (coerce poly2 'ntt-vector))
         (len (max (length poly1) (length poly2)))
         (res (make-array len :element-type 'ntt-int :initial-element 0)))
    (replace res poly1)
    (dotimes (i (length poly2))
      (let ((value (+ (aref res i)
                      (the ntt-int (- +ntt-mod+ (aref poly2 i))))))
        (setf (aref res i)
              (if (>= value +ntt-mod+)
                  (- value +ntt-mod+)
                  value))))
    (let ((end (+ 1 (or (position 0 res :from-end t :test-not #'eql) -1))))
      (adjust-array res end))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-add))
(defun poly-add (poly1 poly2)
  "Returns POLY1 + POLY2 with trailing zero coefficients removed."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((poly1 (coerce poly1 'ntt-vector))
         (poly2 (coerce poly2 'ntt-vector))
         (len (max (length poly1) (length poly2)))
         (res (make-array len :element-type 'ntt-int :initial-element 0)))
    (replace res poly1)
    (dotimes (i (length poly2))
      (let ((value (+ (aref res i) (aref poly2 i))))
        (setf (aref res i)
              (if (>= value +ntt-mod+)
                  (- value +ntt-mod+)
                  value))))
    (let ((end (+ 1 (or (position 0 res :from-end t :test-not #'eql) -1))))
      (adjust-array res end))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-mod))
(defun poly-mod (poly1 poly2)
  "Returns the remainder of POLY1 divided by POLY2 with trailing zero
coefficients removed. Returns an empty vector when POLY1 is identically zero."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let ((poly1 (coerce poly1 'ntt-vector))
        (poly2 (coerce poly2 'ntt-vector)))
    (when (loop for x across poly1 always (zerop x))
      (return-from poly-mod (make-array 0 :element-type 'ntt-int)))
    (let* ((res (poly-sub poly1 (poly-multiply (poly-floor poly1 poly2) poly2)))
           (end (+ 1 (or (position 0 res :from-end t :test-not #'eql) -1))))
      (subseq res 0 end))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-total-prod))
(defun poly-total-prod (polys)
  "Returns the total polynomial product: polys[0] * polys[1] * ... * polys[n-1]."
  (declare (vector polys))
  (let* ((n (length polys))
         (dp (make-array n :element-type t)))
    (declare ((mod #.array-dimension-limit) n))
    ;; The empty product is the constant polynomial 1.
    (when (zerop n)
      (return-from poly-total-prod
        (make-array 1 :element-type 'ntt-int :initial-element 1)))
    (replace dp polys)
    ;; Pairwise (tournament) multiplication; the result accumulates in dp[0].
    (loop for width of-type (mod #.array-dimension-limit) = 1 then (ash width 1)
          while (< width n)
          do (loop for i of-type (mod #.array-dimension-limit) from 0 by (* width 2)
                   while (< (+ i width) n)
                   do (setf (aref dp i)
                            (poly-multiply (aref dp i) (aref dp (+ i width))))))
    (coerce (the vector (aref dp 0)) 'ntt-vector)))

(declaim (ftype (function * (values ntt-vector &optional)) multipoint-eval))
(defun multipoint-eval (poly points)
  "The length of POINTS must be a power of two."
  (declare (optimize (speed 3))
           (vector poly points)
           #+sbcl (sb-ext:muffle-conditions style-warning))
  (check-ntt-vector points)
  (let* ((poly (coerce poly 'ntt-vector))
         (points (coerce points 'ntt-vector))
         (len (length points))
         ;; Implicit binary tree stored as an array: the children of node POS
         ;; are 2*POS+1 and 2*POS+2.
         (table (make-array (max 0 (- (* 2 len) 1)) :element-type 'ntt-vector))
         (res (make-array len :element-type 'ntt-int)))
    (unless (zerop len)
      ;; Build the subproduct tree: a leaf holds (x - point) and an inner
      ;; node holds the product of its children.
      (labels ((%build (l r pos)
                 (declare ((mod #.array-dimension-limit) l r pos))
                 (if (= (- r l) 1)
                     (let ((lin (make-array 2 :element-type 'ntt-int)))
                       (setf (aref lin 0) (- +ntt-mod+ (aref points l)) ;; NOTE: non-zero
                             (aref lin 1) 1)
                       (setf (aref table pos) lin))
                     (let ((mid (ash (+ l r) -1)))
                       (%build l mid (+ 1 (* pos 2)))
                       (%build mid r (+ 2 (* pos 2)))
                       (setf (aref table pos)
                             (poly-multiply (aref table (+ 1 (* pos 2)))
                                            (aref table (+ 2 (* pos 2)))))))))
        (%build 0 len 0))
      ;; Descend the tree, reducing the polynomial modulo each node; at a
      ;; leaf the remainder's constant term is the value at that point.
      (labels ((%eval (poly l r pos)
                 (declare ((mod #.array-dimension-limit) l r pos))
                 (if (= (- r l) 1)
                     (let ((tmp (poly-mod poly (aref table pos))))
                       (setf (aref res l) (if (zerop (length tmp)) 0 (aref tmp 0))))
                     (let ((mid (ash (+ l r) -1)))
                       (%eval (poly-mod poly (aref table (+ (* 2 pos) 1)))
                              l mid (+ (* 2 pos) 1))
                       (%eval (poly-mod poly (aref table (+ (* 2 pos) 2)))
                              mid r (+ (* 2 pos) 2))))))
        (%eval poly 0 len 0)))
    res))

;; not tested
(declaim (ftype (function * (values ntt-vector &optional)) chirp-z))
(defun chirp-z (poly base length)
  "Does multipoint evaluation of POLY with powers of BASE: P(base^0), P(base^1),
..., P(base^(length-1)). BASE must be coprime with modulus. Time complexity is
O((N+MOD)log(N+MOD)).

Reference:
https://codeforces.com/blog/entry/83532"
  (declare (optimize (speed 3))
           (vector poly)
           (ntt-int base length))
  (when (zerop (length poly))
    (return-from chirp-z (make-array length :element-type 'ntt-int :initial-element 0)))
  (let* ((poly (coerce poly 'ntt-vector))
         (binv (mod-inverse base +ntt-mod+))
         (n (length poly))
         (m (max length n))
         (n+m (+ n m))
         (cs (make-array n :element-type 'ntt-int :initial-element 0))
         (ds (make-array n+m :element-type 'ntt-int :initial-element 0)))
    (declare (ntt-int n m n+m))
    ;; CS: reversed coefficients weighted by BINV^C(k, 2); DS: BASE^C(i, 2).
    (dotimes (i n)
      (setf (aref cs i)
            (mod (* (aref poly (- n 1 i))
                    (mod-power binv (ash (* (- n 1 i) (- n 2 i)) -1) +ntt-mod+))
                 +ntt-mod+)))
    (dotimes (i n+m)
      (setf (aref ds i) (mod-power base (ash (* i (- i 1)) -1) +ntt-mod+)))
    (let ((result (subseq (poly-multiply cs ds) (- n 1) (+ (- n 1) length))))
      (dotimes (i length)
        (setf (aref result i)
              (mod (* (aref result i)
                      (mod-power binv (ash (* i (- i 1)) -1) +ntt-mod+))
                   +ntt-mod+)))
      result)))

(declaim (ftype (function * (values ntt-int &optional)) bostan-mori))
(defun bostan-mori (index num denom)
  "Returns [x^index](num(x)/denom(x)).

Reference:
https://arxiv.org/abs/2008.08822
https://qiita.com/ryuhe1/items/da5acbcce4ac1911f47 (Japanese)"
  (declare (optimize (speed 3))
           (unsigned-byte index)
           (vector num denom))
  (labels (;; Coefficients at even indices of P.
           (even (p)
             (let ((res (make-array (ceiling (length p) 2) :element-type 'ntt-int)))
               (dotimes (i (length res))
                 (setf (aref res i) (aref p (* 2 i))))
               res))
           ;; Coefficients at odd indices of P.
           (odd (p)
             (let ((res (make-array (floor (length p) 2) :element-type 'ntt-int)))
               (dotimes (i (length res))
                 (setf (aref res i) (aref p (+ 1 (* 2 i)))))
               res))
           ;; P(-x): a copy of P with the odd-index coefficients negated.
           (negate (p)
             (let ((res (copy-seq p)))
               (loop for i from 1 below (length res) by 2
                     do (setf (aref res i)
                              (if (zerop (aref res i))
                                  0
                                  (- +ntt-mod+ (aref res i)))))
               res)))
    (let ((num (coerce num 'ntt-vector))
          (denom (coerce denom 'ntt-vector)))
      (when (or (zerop (length denom))
                (zerop (aref denom 0)))
        (error 'division-by-zero
               :operands (list num denom)
               :operation 'bostan-mori))
      ;; Each step multiplies numerator and denominator by DENOM(-x), keeps
      ;; the half of the numerator matching the parity of INDEX, and halves
      ;; INDEX.
      (loop while (> index 0)
            for denom- = (negate denom)
            for u = (poly-multiply num denom-)
            when (evenp index)
            do (setq num (even u))
            else
            do (setq num (odd u))
            do (setq denom (even (poly-multiply denom denom-))
                     index (ash index -1))
            finally (return (rem (* (if (zerop (length num)) 0 (aref num 0))
                                    (mod-inverse (aref denom 0) +ntt-mod+))
                                 +ntt-mod+))))))

(declaim (inline poly-differentiate!))
(defun poly-differentiate! (p)
  "Returns the derivative of P with trailing zero coefficients removed. When P
is already an NTT-VECTOR its contents are overwritten."
  (declare (vector p))
  (let ((p (coerce p 'ntt-vector)))
    (when (zerop (length p))
      (return-from poly-differentiate! p))
    (dotimes (i (- (length p) 1))
      (declare (ntt-int i))
      (let ((coef (mod (* (aref p (+ i 1)) (+ i 1)) +ntt-mod+)))
        (declare ((integer 0 #.most-positive-fixnum) coef))
        (setf (aref p i) coef)))
    ;; The last slot still holds the old leading coefficient, so it is
    ;; excluded from the search for the last non-zero coefficient.
    (let ((end (+ 1 (or (position 0 p :from-end t :end (- (length p) 1) :test-not #'eql)
                        -1))))
      (subseq p 0 end))))

;; Cache of modular inverses: (aref *inv* x) is the inverse of x modulo
;; +NTT-MOD+ for x >= 1; index 0 is a placeholder.
(declaim (ntt-vector *inv*))
(defparameter *inv* (make-array 2 :element-type 'ntt-int :initial-contents '(0 1)))

(defun fill-inv! (new-size)
  "Grows *INV* to a power-of-two length that is at least NEW-SIZE and fills in
the new entries."
  (declare (optimize (speed 3))
           ((mod #.array-dimension-limit) new-size))
  (let* ((old-size (length *inv*))
         (new-size (%power-of-two-ceiling (max old-size new-size)))
         (inv (adjust-array *inv* new-size)))
    (declare (ntt-vector inv))
    ;; inv[x] = -(floor(mod / x) * inv[mod rem x]) mod +NTT-MOD+
    (loop for x from old-size below new-size
          do (setf (aref inv x)
                   (- +ntt-mod+
                      (mod (* (aref inv (rem +ntt-mod+ x))
                              (floor +ntt-mod+ x))
                           +ntt-mod+))))
    (setq *inv* inv)))

(declaim (inline poly-integrate))
(defun poly-integrate (p)
  "Returns an indefinite integral of P. Assumes the integration constant to be
zero. The result is one element longer than P, except that an empty P yields an
empty vector."
  (declare (vector p))
  (let* ((p (coerce p 'ntt-vector))
         (n (length p)))
    (when (zerop n)
      (return-from poly-integrate (make-array 0 :element-type 'ntt-int)))
    (fill-inv! (+ n 1))
    (let ((result (make-array (+ n 1) :element-type 'ntt-int :initial-element 0))
          (inv *inv*))
      (dotimes (i n)
        (setf (aref result (+ i 1))
              (mod (* (the fixnum (aref p i)) (aref inv (+ i 1))) +ntt-mod+)))
      result)))

(declaim (ftype (function * (values ntt-vector &optional)) poly-log))
(defun poly-log (poly &optional result-length)
  "Returns the first RESULT-LENGTH coefficients (default: the length of POLY) of
log(POLY) as a formal power series, computed as the integral of POLY' / POLY.
POLY must be non-empty and its constant term must be 1; otherwise an assertion
fails."
  (declare #.cl-user::*opt*
           (vector poly)
           ((or null (integer 1 (#.array-dimension-limit))) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (length (or result-length (length poly))))
    ;; POLY itself must be non-empty before its constant term is read; this
    ;; matters when RESULT-LENGTH is supplied for an empty POLY.
    (assert (and (> length 0)
                 (> (length poly) 0)
                 (= 1 (aref poly 0))))
    (let ((res (poly-integrate
                (%adjust (poly-multiply (poly-differentiate! (copy-seq poly))
                                        (poly-inverse poly length))
                         (- length 1)))))
      ;; Pad to LENGTH rather than RESULT-LENGTH: POLY-INTEGRATE returns an
      ;; empty vector for LENGTH = 1, which must still become #(0) when
      ;; RESULT-LENGTH was not supplied.
      (%adjust res length))))

(defpackage :cp/read-fixnum
  (:use :cl)
  (:export #:read-fixnum))
(in-package :cp/read-fixnum)

(declaim (ftype (function * (values fixnum &optional)) read-fixnum))
(defun read-fixnum (&optional (in *standard-input*))
  "NOTE: cannot read -2^62"
  (declare #.cl-user::*opt*)
  (macrolet ((%read-byte ()
               `(the (unsigned-byte 8)
                     #+swank (char-code (read-char in nil #\Nul))
                     #-swank (sb-impl::ansi-stream-read-byte in nil #.(char-code #\Nul) nil))))
    ;; Skip everything up to the first digit, remembering whether a minus
    ;; sign was seen; a zero byte is treated as EOF.
    (let* ((minus nil)
           (result (loop (let ((byte (%read-byte)))
                           (cond ((<= 48 byte 57)
                                  (return (- byte 48)))
                                 ((zerop byte) ; #\Nul
                                  (error "Read EOF or #\Nul."))
                                 ((= byte #.(char-code #\-))
                                  (setq minus t)))))))
      (declare ((integer 0 #.most-positive-fixnum) result))
      ;; Accumulate digits until the first non-digit byte.
      (loop
        (let* ((byte (%read-byte)))
          (if (<= 48 byte 57)
              (setq result (+ (- byte 48)
                              (* 10 (the (integer 0 #.(floor most-positive-fixnum 10))
                                         result))))
              (return (if minus (- result) result))))))))

(defpackage :cp/println-sequence
  (:use :cl)
  (:export #:println-sequence))
(in-package :cp/println-sequence)

(declaim (inline println-sequence))
(defun println-sequence (sequence &key (out *standard-output*) (key #'identity))
  "Prints the elements of SEQUENCE (each passed through KEY) to OUT, separated
by single spaces and terminated by a newline."
  (let ((init t))
    (sequence:dosequence (x sequence)
      (if init
          (setq init nil)
          ;; Spelled #\Space so the form does not depend on the exact
          ;; whitespace that follows the backslash.
          (write-char #\Space out))
      (princ (funcall key x) out))
    (terpri out)))

;; BEGIN_USE_PACKAGE
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/println-sequence :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/read-fixnum :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/polynomial-ntt :cl-user))
(in-package :cl-user)

;;;
;;; Body
;;;

(defun main ()
  "Reads N and then N coefficients from standard input, and prints the first N
coefficients of the logarithm of that power series on one line."
  (declare #.*opt*)
  (let* ((n (read))
         (as (make-array n :element-type 'uint31 :initial-element 0)))
    (dotimes (i n)
      (setf (aref as i) (read-fixnum)))
    (let ((res (poly-log as n)))
      ;; Build the whole output in a base-char string first and write it out
      ;; in one call.
      (write-string
       (with-output-to-string (*standard-output* nil :element-type 'base-char)
         (println-sequence res))))))

#-swank (main)

;;;
;;; Test
;;;

#+swank
(progn
  (defparameter *lisp-file-pathname* (uiop:current-lisp-file-pathname))
  (setq *default-pathname-defaults* (uiop:pathname-directory-pathname *lisp-file-pathname*))
  (uiop:chdir *default-pathname-defaults*)
  (defparameter *dat-pathname* (uiop:merge-pathnames* "test.dat" *lisp-file-pathname*))
  (defparameter *problem-url* "PROBLEM_URL_TO_BE_REPLACED"))

#+swank
(defun gen-dat ()
  (uiop:with-output-file (out *dat-pathname* :if-exists :supersede)
    (format out "")))

#+swank
(defun bench (&optional (out (make-broadcast-stream)))
  (time (run *dat-pathname* out)))

;; Fail the compilation when any undefined-function/variable warning was
;; recorded.
#+(and sbcl (not swank))
(eval-when (:compile-toplevel)
  (when sb-c::*undefined-warnings*
    (error "undefined warnings: ~{~A~^ ~}" sb-c::*undefined-warnings*)))