Submit Info #67532

Problem Lang User Status Time Memory
Partition Function lisp sansaqua AC 240 ms 32.95 MiB

ケース詳細
Name Status Time Memory
0_00 AC 25 ms 7.57 MiB
100000_00 AC 73 ms 13.25 MiB
10000_00 AC 29 ms 8.23 MiB
1000_00 AC 25 ms 7.68 MiB
100_00 AC 24 ms 7.57 MiB
1_00 AC 24 ms 7.52 MiB
200000_00 AC 127 ms 18.95 MiB
300000_00 AC 229 ms 27.64 MiB
400000_00 AC 234 ms 30.38 MiB
500000_00 AC 240 ms 32.95 MiB
example_00 AC 25 ms 7.66 MiB

(in-package :cl-user)

;;; Contest prelude.  Forms guarded by #+swank only take effect in an
;;; interactive SLIME/SWANK session; forms guarded by #-swank are for the
;;; stand-alone (judge) build.
(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; Optimization policy: safe while developing, fully unsafe when submitted.
  (defparameter *opt*
    #+swank '(optimize (speed 3) (safety 2))
    #-swank '(optimize (speed 3) (safety 0) (debug 0)))
  #+swank (ql:quickload '(:cl-debug-print :fiveam :cp/util) :silent t)
  #+swank (use-package :cp/util :cl-user)
  ;; Outside SWANK, #>FORM simply reads as (VALUES FORM), i.e. the debug
  ;; printer becomes a no-op wrapper.
  #-swank (set-dispatch-macro-character
           #\# #\>
           (lambda (s c p)
             (declare (ignore c p))
             `(values ,(read s nil nil t))))
  #+sbcl (dolist (f '(:popcnt :sse4))
           (pushnew f sb-c:*backend-subfeatures*))
  (setq *random-state* (make-random-state t)))

;; Turn every non-style warning raised while compiling into a break.
#-swank (eval-when (:compile-toplevel)
          (setq *break-on-signals* '(and warning (not style-warning))))
#+swank (set-dispatch-macro-character #\# #\> #'cl-debug-print:debug-print-reader)

;; Defines UINT<b> / INT<b> as (UNSIGNED-BYTE b) / (SIGNED-BYTE b) for each
;; listed bit width.
(macrolet ((def (b)
             `(progn (deftype ,(intern (format nil "UINT~A" b)) () '(unsigned-byte ,b))
                     (deftype ,(intern (format nil "INT~A" b)) () '(signed-byte ,b))))
           (define-int-types (&rest bits)
             `(progn ,@(mapcar (lambda (b) `(def ,b)) bits))))
  (define-int-types 2 4 7 8 15 16 31 32 62 63 64))

(defconstant +mod+ 998244353)

;; Prints "FORM => VALUE" to *ERROR-OUTPUT* under SWANK.  Without SWANK the
;; macro body is empty, so a DBG call expands to NIL.
(defmacro dbg (&rest forms)
  (declare (ignorable forms))
  #+swank
  (if (= (length forms) 1)
      `(format *error-output* "~A => ~A~%" ',(car forms) ,(car forms))
      `(format *error-output* "~A => ~A~%" ',forms `(,,@forms))))

;; Prints OBJ followed by a newline and returns OBJ.  For a DOUBLE-FLOAT the
;; default read format is rebound to DOUBLE-FLOAT while printing.
(declaim (inline println))
(defun println (obj &optional (stream *standard-output*))
  (let ((*read-default-float-format*
          (if (typep obj 'double-float) 'double-float *read-default-float-format*)))
    (prog1 (princ obj stream) (terpri stream))))

;; BEGIN_INSERTED_CONTENTS
(defpackage :cp/barrett
  (:use :cl)
  (:import-from #:sb-ext #:truly-the)
  (:import-from #:sb-c
                #:deftransform
                #:derive-type
                #:defoptimizer
                #:defknown
                #:movable
                #:foldable
                #:flushable
                #:commutative
                #:always-translatable
                #:lvar-type
                #:lvar-value
                #:give-up-ir1-transform
                #:integer-type-numeric-bounds
                #:define-vop
                #:tn-offset)
  (:import-from #:sb-vm
                #:move
                #:inst
                #:rax-offset
                #:rdx-offset
                #:temp-reg-tn
                #:any-reg
                #:control-stack
                #:unsigned-reg
                #:positive-fixnum
                #:fixnumize
                #:ea)
(:import-from #:sb-kernel #:specifier-type)
  (:import-from #:sb-int #:explicit-check #:constant-arg)
  (:export #:fast-mod #:%himod #:%lomod)
  (:documentation "Provides Barrett reduction."))
(in-package :cp/barrett)

;;; Type-derivation helpers used by the DEFOPTIMIZER forms below.
(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; Result type of *-HIGH62: when both upper bounds are known integers the
  ;; bound is (high1 * high2) >> 62, otherwise just a non-negative integer.
  (defun derive-* (x y)
    (let ((high1 (nth-value 1 (integer-type-numeric-bounds (lvar-type x))))
          (high2 (nth-value 1 (integer-type-numeric-bounds (lvar-type y)))))
      (specifier-type (if (and (integerp high1) (integerp high2))
                          `(integer 0 ,(ash (* high1 high2) -62))
                          `(integer 0)))))
  ;; Result type of a reduction: [0, high) when the modulus has a known
  ;; integer upper bound HIGH.
  (defun derive-mod (modulus)
    (let ((high (nth-value 1 (integer-type-numeric-bounds (lvar-type modulus)))))
      (specifier-type (if (integerp high)
                          `(integer 0 (,high))
                          `(integer 0)))))
  ;; Calls SB-VM::GPR-TN-P when that symbol exists at read time; otherwise the
  ;; body is the constant T.
  (defun gpr-tn-p (x)
    (declare (ignorable x))
    #.(if (find-symbol "GPR-TN-P" :sb-vm)
          `(funcall (intern "GPR-TN-P" :sb-vm) x)
          t)))

;; *-high62
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown *-high62 ((unsigned-byte 62) (unsigned-byte 62)) (unsigned-byte 62)
      (movable foldable flushable commutative always-translatable)
    :overwrite-fndb-silently t)
  (defoptimizer (*-high62 derive-type) ((x y))
    (derive-* x y))
  ;; MUL on the two fixnum-tagged operands, then shift the high word (RDX)
  ;; left by one and return it as a fixnum.
  (define-vop (fast-*-high62/fixnum)
    (:translate *-high62)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target rax)
           (y :scs (any-reg control-stack)))
    (:arg-types positive-fixnum positive-fixnum)
    (:temporary (:sc any-reg :offset rax-offset
                 :from (:argument 0) :to :result)
                rax)
    (:temporary (:sc any-reg :offset rdx-offset :target r
                 :from :eval :to :result)
                rdx)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline *-high62")
    (:vop-var vop)
    (:save-p :compute-only)
    (:generator 6
      (move rax x)
      (inst mul rax y)
      (inst shl rdx 1)
      (move r rdx)))
  ;; NOTE: I'm not using it for now because registered constant is slower
  ;; (define-vop (fast-c-*-high62-/fixnum)
  ;;   (:translate *-high62)
  ;;   (:policy :fast-safe)
  ;;   (:args (x :scs (any-reg) :target rax))
  ;;   (:info y)
  ;;   (:arg-types positive-fixnum (:constant (unsigned-byte 62)))
  ;;   (:temporary (:sc any-reg :offset rax-offset
  ;;                :from (:argument 0) :to :result)
  ;;               rax)
  ;;   (:temporary (:sc any-reg :offset rdx-offset :target r
  ;;                :from :eval :to :result)
  ;;               rdx)
  ;;   (:results (r :scs (any-reg)))
  ;;   (:result-types positive-fixnum)
  ;;   (:note "inline constant *-high62")
  ;;   (:vop-var vop)
  ;;   (:save-p :compute-only)
  ;;   (:generator 5
  ;;     (move rax x)
  ;;     (inst mul rax (sb-c:register-inline-constant :qword (fixnumize y)))
  ;;     (inst shl rdx 1)
  ;;     (move r rdx)))
  ;; Out-of-line definition; the body is translated by the VOP above.
  (defun *-high62 (x y)
    (declare (explicit-check))
    (*-high62 x y)))

;; %himod
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown %himod ((unsigned-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
      (movable foldable flushable always-translatable)
    :overwrite-fndb-silently t)
  (defoptimizer (%himod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))
  ;; Branch-free conditional subtraction with a constant modulus M:
  ;; R := X, and R := X - M when X compares (unsigned) above M - 1.
  (define-vop (fast-c-%himod)
    (:translate %himod)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target r))
    (:info m)
    (:arg-types positive-fixnum (:constant (unsigned-byte 31)))
    (:temporary (:sc any-reg :from :eval :to :result
                 ;; FIXME: hack to avoid collision of X and Y
                 :offset #.(tn-offset temp-reg-tn))
                y)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline constant %himod")
    (:vop-var vop)
    (:generator 4
      ;; maybe verbose
      (assert (gpr-tn-p x))
      (assert (not (sb-c:location= x y)))
      (when (sb-c:tn-p m)
        (assert (sb-c:sc-is m sb-vm::immediate))
        (setq m (sb-c::tn-value m)))
      (setq m (fixnumize m))
      (move r x)
      (inst cmp r (- m 2))
      (inst lea y #.(if (fboundp 'ea)
                        `(sb-vm::ea (- m) r)
                        `(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp (- m) :base r)))
      (inst cmov :a r y))))

;; %lomod
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown %lomod ((signed-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
      (movable foldable flushable always-translatable)
    :overwrite-fndb-silently t)
  (defoptimizer (%lomod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))
  (define-vop
(fast-c-%lomod)
    ;; Branch-free conditional addition with a constant modulus M:
    ;; R := X, and R := X + M when X is negative.
    (:translate %lomod)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target r))
    (:info m)
    (:arg-types fixnum (:constant (unsigned-byte 31)))
    (:temporary (:sc any-reg :from :eval :to :result
                 :offset #.(tn-offset temp-reg-tn))
                y)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline constant %lomod")
    (:vop-var vop)
    (:generator 4
      ;; maybe verbose
      (assert (gpr-tn-p x))
      (assert (not (sb-c:location= x y)))
      (when (sb-c:tn-p m)
        (assert (sb-c:sc-is m sb-vm::immediate))
        (setq m (sb-c::tn-value m)))
      (setq m (fixnumize m))
      (move r x)
      ;; TODO: this instruction is rarely necessary because %LOMOD tends to be
      ;; called right after subtraction.
      (inst or r r)
      (inst lea y #.(if (fboundp 'ea)
                        `(sb-vm::ea m r)
                        `(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp m :base r)))
      (inst cmov :l r y))))

;; fast-mod
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown fast-mod (integer unsigned-byte) unsigned-byte
      (movable foldable flushable)
    :overwrite-fndb-silently t)
  (defoptimizer (fast-mod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))
  (defun fast-mod (number divisor)
    "Is equivalent to CL:MOD for integer arguments"
    (declare (explicit-check))
    (mod number divisor))
  ;; Applies only when NUMBER is an (UNSIGNED-BYTE 62) and DIVISOR is a
  ;; compile-time constant (UNSIGNED-BYTE 31).  With m = floor(2^62 / mod) the
  ;; quotient estimate is q = *-HIGH62(number, m); the remainder candidate
  ;; number - q*mod is corrected by at most one subtraction of mod.
  (deftransform fast-mod ((number divisor)
                          ((unsigned-byte 62) (constant-arg (unsigned-byte 31))) *
                          :important t)
    "convert FAST-MOD to barrett reduction"
    (let ((mod (lvar-value divisor)))
      (if (<= mod 1)
          (give-up-ir1-transform)
          (let ((m (floor (ash 1 62) mod)))
            `(let* ((q (*-high62 number ,m))
                    (x (truly-the (signed-byte 32) (- number (* q ,mod)))))
               (if (< x ,mod) x (- x ,mod))
               ;; Not so effective because branch prediction is usually correct?
;; (sb-ext:truly-the (mod ,mod) (%himod x ,mod))
               ))))))

(defpackage :cp/static-mod
  (:use :cl)
  (:export #:+mod+))
(in-package :cp/static-mod)

;; Takes the value of CL-USER::+MOD+ when that symbol is bound; otherwise
;; falls back to 998244353.
(defconstant +mod+ (if (boundp 'cl-user::+mod+)
                       (symbol-value 'cl-user::+mod+)
                       998244353))

(defpackage :cp/mod-inverse
  (:use :cl)
  #+sbcl (:import-from #:sb-c #:defoptimizer #:lvar-type #:integer-type-numeric-bounds
                       #:derive-type #:flushable #:foldable)
  #+sbcl (:import-from :sb-kernel #:specifier-type)
  (:export #:mod-inverse))
(in-package :cp/mod-inverse)

;; Tell the SBCL compiler that both functions return a value in [0, modulus)
;; whenever the modulus has a known integer upper bound.
#+sbcl
(eval-when (:compile-toplevel :load-toplevel :execute)
  (sb-c:defknown %mod-inverse ((integer 0) (integer 1)) (integer 0)
      (flushable foldable)
    :overwrite-fndb-silently t)
  (sb-c:defknown mod-inverse (integer (integer 1)) (integer 0)
      (flushable foldable)
    :overwrite-fndb-silently t)
  (defun derive-mod (modulus)
    (let ((high (nth-value 1 (integer-type-numeric-bounds (lvar-type modulus)))))
      (specifier-type (if (integerp high)
                          `(integer 0 (,high))
                          `(integer 0)))))
  (defoptimizer (%mod-inverse derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))
  (defoptimizer (mod-inverse derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus)))

;; Extended Euclidean iteration on (A, B) = (INTEGER, MODULUS) that tracks the
;; coefficient pair (U, V); a negative U is shifted by MODULUS before being
;; returned.  FROB instantiates the same loop for three integer widths chosen
;; by the size of MODULUS.  Compiled with (SAFETY 0): no argument checking and
;; no coprimality check happen here.
(defun %mod-inverse (integer modulus)
  (declare (optimize (speed 3) (safety 0))
           #+sbcl (sb-ext:muffle-conditions sb-ext:compiler-note))
  (macrolet ((frob (stype)
               `(let ((a integer)
                      (b modulus)
                      (u 1)
                      (v 0))
                  (declare (,stype a b u v))
                  (loop until (zerop b)
                        for quot = (floor a b)
                        do (decf a (the ,stype (* quot b)))
                           (rotatef a b)
                           (decf u (the ,stype (* quot v)))
                           (rotatef u v))
                  (if (< u 0)
                      (+ u modulus)
                      u))))
    (typecase modulus
      ((unsigned-byte 31) (frob (signed-byte 32)))
      ((unsigned-byte 62) (frob (signed-byte 63)))
      (otherwise (frob integer)))))

(declaim (inline mod-inverse))
(defun mod-inverse (integer modulus)
  "Solves ax = 1 mod m. Signals DIVISION-BY-ZERO when INTEGER and MODULUS are not coprime."
(let* ((integer (mod integer modulus))
       (result (%mod-inverse integer modulus)))
  ;; Verify the candidate: INTEGER * RESULT must be 1 mod MODULUS (the
  ;; modulus 1 is accepted as a special case).
  (unless (or (= 1 (mod (* integer result) modulus))
              (= 1 modulus))
    (error 'division-by-zero
           :operands (list integer modulus)
           :operation 'mod-inverse))
  result))

(defpackage :cp/mod-power
  (:use :cl)
  (:export #:mod-power))
(in-package :cp/mod-power)

;; Square-and-multiply over the bits of POWER, least significant bit first.
(declaim (inline mod-power))
(defun mod-power (base power modulus)
  "Returns BASE^POWER mod MODULUS. Note: 0^0 = 1. BASE := integer POWER, MODULUS := non-negative fixnum"
  (declare ((integer 0 #.most-positive-fixnum) modulus power)
           (integer base))
  (let ((base (mod base modulus))
        (acc (mod 1 modulus)))
    (declare ((integer 0 #.most-positive-fixnum) base acc))
    (loop while (plusp power)
          do (when (oddp power)
               (setq acc (mod (* acc base) modulus)))
             (setq base (mod (* base base) modulus)
                   power (ash power -1)))
    acc))

(defpackage :cp/zeta-integer
  (:use :cl)
  (:export #:divisor-transform! #:inverse-divisor-transform!
           #:multiple-transform! #:inverse-multiple-transform!)
  (:documentation "Provides fast zeta/Moebius transforms w.r.t. divisor or multiple. Time complexity is O(nloglog(n))."))
(in-package :cp/zeta-integer)

;; SIEVE[x] = 1 while x has not yet been visited as a multiple of some p; an
;; index p that is still 1 when reached drives one pass over its multiples in
;; ascending order.
(declaim (inline divisor-transform!))
(defun divisor-transform! (vector &optional (op+ #'+) (handle-zero t))
  "Sets each VECTOR[i] to the sum of VECTOR[d] for all the divisors d of i. Ignores VECTOR[0] when HANDLE-ZERO is NIL."
  (declare (vector vector))
  (let* ((n (length vector))
         (sieve (make-array n :element-type 'bit :initial-element 1)))
    ;; Fold every other element into VECTOR[0].
    (when handle-zero
      (loop for idx from 1 below n
            do (setf (aref vector 0)
                     (funcall op+ (aref vector 0) (aref vector idx)))))
    (loop for p from 2 below n
          when (= 1 (sbit sieve p))
            do (loop for q from 1 below (ceiling n p)
                     for multiple of-type fixnum = (* q p)
                     do (setf (sbit sieve multiple) 0)
                        (setf (aref vector multiple)
                              (funcall op+ (aref vector multiple) (aref vector q)))))
    vector))

(declaim (inline inverse-divisor-transform!))
(defun inverse-divisor-transform!
(vector &optional (op- #'-) (handle-zero t))
  "Does the inverse transform of DIVISOR-TRANSFORM!. Ignores VECTOR[0] when HANDLE-ZERO is NIL."
  (declare (vector vector))
  (let* ((n (length vector))
         (sieve (make-array n :element-type 'bit :initial-element 1)))
    ;; Same sieve-driven sweep as the forward transform, but the multiples of
    ;; each p are visited in descending order and OP- is applied.
    (loop for p from 2 below n
          when (= 1 (sbit sieve p))
            do (loop for q from (- (ceiling n p) 1) downto 1
                     for multiple of-type fixnum = (* q p)
                     do (setf (sbit sieve multiple) 0)
                        (setf (aref vector multiple)
                              (funcall op- (aref vector multiple) (aref vector q)))))
    ;; Remove every other element from VECTOR[0] last.
    (when handle-zero
      (loop for idx from 1 below n
            do (setf (aref vector 0)
                     (funcall op- (aref vector 0) (aref vector idx)))))
    vector))

(declaim (inline multiple-transform!))
(defun multiple-transform! (vector &optional (op+ #'+) (handle-zero t))
  "Sets each VECTOR[i] to the sum of VECTOR[m] for all the multiples m of i. (To be precise, all the multiples smaller than the length of VECTOR.) Ignores VECTOR[0] when HANDLE-ZERO is NIL."
  (declare (vector vector))
  (let* ((n (length vector))
         (sieve (make-array n :element-type 'bit :initial-element 1)))
    ;; For each p still marked in the sieve, accumulate VECTOR[q*p] into
    ;; VECTOR[q], with q descending.
    (loop for p from 2 below n
          when (= 1 (sbit sieve p))
            do (loop for q from (- (ceiling n p) 1) downto 1
                     for multiple of-type fixnum = (* q p)
                     do (setf (sbit sieve multiple) 0)
                        (setf (aref vector q)
                              (funcall op+ (aref vector q) (aref vector multiple)))))
    ;; Add VECTOR[0] to every other element.
    (when handle-zero
      (loop for idx from 1 below n
            do (setf (aref vector idx)
                     (funcall op+ (aref vector 0) (aref vector idx)))))
    vector))

(declaim (inline inverse-multiple-transform!))
(defun inverse-multiple-transform! (vector &optional (op- #'-) (handle-zero t))
  "Does the inverse transform of MULTIPLE-TRANSFORM!. Ignores VECTOR[0] when HANDLE-ZERO is NIL."
(declare (vector vector))
  (let* ((n (length vector))
         (sieve (make-array n :element-type 'bit :initial-element 1)))
    ;; Subtract VECTOR[0] from every other element first.
    (when handle-zero
      (loop for i from 1 below n
            do (setf (aref vector i)
                     (funcall op- (aref vector i) (aref vector 0)))))
    ;; For each p still marked in the sieve, remove VECTOR[k*p] from VECTOR[k],
    ;; with k ascending.
    (loop for p from 2 below n
          when (= 1 (sbit sieve p))
            do (loop for k from 1 below (ceiling n p)
                     for pmult of-type fixnum = (* k p)
                     do (setf (sbit sieve pmult) 0)
                        (setf (aref vector k)
                              (funcall op- (aref vector k) (aref vector pmult)))))
    vector))

;;;
;;; (Slower) Zeta/Moebius transforms w.r.t. divisor or multiple in O(nlog(n)) time
;;;

;; (declaim (inline divisor-transform!))
;; (defun divisor-transform! (vector &optional (op+ #'+) (handle-zero t))
;;   "Sets each VECTOR[i] to the sum of VECTOR[d] for all the divisors d of i in
;; O(nlog(n)) time."
;;   (declare (vector vector))
;;   (let ((n (length vector)))
;;     (when handle-zero
;;       (loop for i from 1 below n
;;             do (setf (aref vector 0)
;;                      (funcall op+ (aref vector 0) (aref vector i)))))
;;     (loop for i from (- (ceiling n 2) 1) downto 1
;;           do (loop for j from (+ i i) below n by i
;;                    do (setf (aref vector j)
;;                             (funcall op+ (aref vector i) (aref vector j)))))
;;     vector))

;; (declaim (inline inverse-divisor-transform!))
;; (defun inverse-divisor-transform! (vector &optional (op- #'-) (handle-zero t))
;;   "Does the inverse transform of DIVISOR-TRANSFORM! in O(nlog(n)) time."
;;   (declare (vector vector))
;;   (let ((n (length vector)))
;;     (loop for i from 1 below (ceiling n 2)
;;           do (loop for j from (+ i i) below n by i
;;                    do (setf (aref vector j)
;;                             (funcall op- (aref vector j) (aref vector i)))))
;;     (when handle-zero
;;       (loop for i from 1 below n
;;             do (setf (aref vector 0)
;;                      (funcall op- (aref vector 0) (aref vector i)))))
;;     vector))

;; (declaim (inline multiple-transform!))
;; (defun multiple-transform! (vector &optional (op+ #'+) (handle-zero t))
;;   "Sets each VECTOR[i] to the sum of VECTOR[m] for all the multiples m of i in
;; O(nlog(n)) time. (To be precise, all the multiples smaller than the length of
;; VECTOR.)"
;;   (declare (vector vector))
;;   (let ((n (length vector)))
;;     (loop for i from 1 below (ceiling n 2)
;;           do (loop for j from (+ i i) below n by i
;;                    do (setf (aref vector i)
;;                             (funcall op+ (aref vector i) (aref vector j)))))
;;     (when handle-zero
;;       (loop for i from 1 below n
;;             do (setf (aref vector i)
;;                      (funcall op+ (aref vector 0) (aref vector i)))))
;;     vector))

;; (declaim (inline inverse-multiple-transform!))
;; (defun inverse-multiple-transform! (vector &optional (op- #'-) (handle-zero t))
;;   "Does the inverse transform of MULTIPLE-TRANSFORM! in O(nlog(n)) time."
;;   (declare (vector vector))
;;   (let ((n (length vector)))
;;     (when handle-zero
;;       (loop for i from 1 below n
;;             do (setf (aref vector i)
;;                      (funcall op- (aref vector i) (aref vector 0)))))
;;     (loop for i from (- (ceiling n 2) 1) downto 1
;;           do (loop for j from (+ i i) below n by i
;;                    do (setf (aref vector i)
;;                             (funcall op- (aref vector i) (aref vector j)))))
;;     vector))

(defpackage :cp/tzcount
  (:use :cl)
  (:export #:tzcount))
(in-package :cp/tzcount)

;; (LOGAND X (- X)) isolates the lowest set bit; its INTEGER-LENGTH minus one
;; is the index of that bit.
(declaim (inline tzcount))
(defun tzcount (x)
  "Returns the number of trailing zero bits of X. Note that (TZCOUNT 0) = -1."
  (- (integer-length (logand x (- x))) 1))

(defpackage :cp/mod-sqrt
  (:use :cl :cp/mod-power :cp/mod-inverse :cp/tzcount)
  (:export #:mod-sqrt)
  (:documentation "Provides Tonelli-Shanks algorithm for finding a modular square root."))
(in-package :cp/mod-sqrt)

;; Bit width of the integers handled by MOD-SQRT.
(defconstant +nbits+ 31)
(deftype uint () '(integer 0 #.(- (ash 1 +nbits+) 1)))

(declaim (inline mod-sqrt))
(defun mod-sqrt (a mod)
  "Returns a modular square root of A if it exists; otherwise returns NIL. MOD must be prime."
(declare ((integer 0) a)
         ((and (integer 1) uint) mod))
  (let ((a (mod a mod)))
    ;; 0 and 1 are their own square roots; modulo 2 every residue is too.
    (when (or (< a 2) (= mod 2))
      (return-from mod-sqrt a))
    ;; Euler's criterion
    (unless (= 1 (mod-power a (ash (- mod 1) -1) mod))
      (return-from mod-sqrt))
    (let* (;; Random value whose (mod-1)/2-th power is not 1.
           (b (loop for b = (+ 1 (random (- mod 1)))
                    while (= 1 (mod-power b (ash (- mod 1) -1) mod))
                    finally (return b)))
           ;; mod - 1 = q * 2^init-shift
           (init-shift (tzcount (- mod 1)))
           (q (ash (- mod 1) (- init-shift)))
           (x (mod-power a (ash (+ q 1) -1) mod))
           (b (mod-power b q mod))
           (/a (mod-inverse a mod))
           (shift 2))
      (declare ((mod #.+nbits+) shift init-shift)
               (uint b q x /a))
      ;; Correct X until X^2 = A: ERROR is X^2 / A; X is multiplied by B when
      ;; ERROR^(2^(init-shift - shift)) is not 1, and B is squared each round.
      (loop until (= a (mod (* x x) mod))
            for error = (mod (* /a (mod (* x x) mod)) mod)
            unless (= 1 (mod-power error (ash 1 (- init-shift shift)) mod))
              do (setq x (mod (* x b) mod))
            do (setq b (mod (* b b) mod))
               (incf shift)
            finally (return x)))))

(defpackage :cp/ntt
  (:use :cl :cp/mod-inverse :cp/barrett)
  (:export #:define-ntt #:check-ntt-vector #:ntt-int #:ntt-vector)
  (:documentation "Provides fast number theoretic transform. Reference: https://github.com/ei1333/library/blob/master/math/fft/number-theoretic-transform-friendly-mod-int.cpp https://github.com/atcoder/ac-library/tree/master/atcoder"))
(in-package :cp/ntt)

(deftype ntt-int () '(unsigned-byte 31))
(deftype ntt-vector () '(simple-array ntt-int (*)))

;;; Helpers needed while DEFINE-NTT is being macroexpanded, hence the
;;; EVAL-WHEN.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (declaim (inline %tzcount))
  (defun %tzcount (x)
    "Returns the number of the trailing zero bits. Note that (%TZCOUNT 0) = -1."
    (- (integer-length (logand x (- x))) 1))
  ;; Square-and-multiply; RES starts at 1.
  (defun %mod-power (base exp modulus)
    (declare (ntt-int base exp modulus))
    (let ((res 1))
      (declare (ntt-int res))
      (loop while (> exp 0)
            when (oddp exp)
              do (setq res (mod (* res base) modulus))
            do (setq base (mod (* base base) modulus)
                     exp (ash exp -1)))
      res))
  ;; Computes X^(MODULUS - 2) mod MODULUS.
  (defun %mod-inverse (x modulus)
    (%mod-power x (- modulus 2) modulus))
  (defun %calc-generator (modulus)
    "MODULUS must be prime."
(declare (ntt-int modulus))
    (assert (>= modulus 2))
    (case modulus
      ;; Hard-coded answers for frequently used moduli.
      (2 1)
      (167772161 3)
      (469762049 3)
      (754974721 11)
      (998244353 3)
      (otherwise
       ;; DIVS[0..END) collects the distinct prime factors of MODULUS - 1:
       ;; 2 first, then odd factors found by trial division.
       (let ((divs (make-array 20 :element-type 'ntt-int :initial-element 0))
             (end 1)
             (x (floor (- modulus 1) 2)))
         (declare ((integer 0 #.most-positive-fixnum) x))
         (setf (aref divs 0) 2)
         (loop while (evenp x)
               do (setq x (floor x 2)))
         (loop for i of-type ntt-int from 3 by 2
               while (<= (* i i) x)
               when (zerop (mod x i))
                 do (setf (aref divs end) i)
                    (incf end)
                    (loop while (zerop (mod x i))
                          do (setq x (floor x i))))
         (when (> x 1)
           (setf (aref divs end) x)
           (incf end))
         ;; Return the smallest G >= 2 such that G^((MODULUS-1)/d) is not 1
         ;; for every collected factor d.
         (loop for g of-type ntt-int from 2
               do (dotimes (i end (return-from %calc-generator g))
                    (when (= 1 (%mod-power g (floor (- modulus 1) (aref divs i)) modulus))
                      (return)))))))))

(declaim (ftype (function * (values ntt-vector &optional)) %adjust-array))
(defun %adjust-array (vector length)
  "This function always copies VECTOR. (ANSI CL doesn't state whether CL:ADJUST-ARRAY should copy the given array or not.)"
  (declare (optimize (speed 3))
           (vector vector)
           ((mod #.array-dimension-limit) length))
  (let ((vector (coerce vector 'ntt-vector)))
    (if (= (length vector) length)
        (copy-seq vector)
        ;; Zero-filled vector of the requested length with the leading
        ;; elements of VECTOR copied in.
        (let ((res (make-array length :element-type 'ntt-int :initial-element 0)))
          (replace res vector)
          res))))

;; Asserts that the length of VECTOR has at most one bit set and fits in
;; NTT-INT.  Note that a length of zero passes this check.
(defun check-ntt-vector (vector)
  (declare (optimize (speed 3))
           (vector vector))
  (let ((len (length vector)))
    (assert (and (zerop (logand len (- len 1))) ; power of two
                 (typep len 'ntt-int)))))

;; (DEFINE-NTT modulus &key ntt inverse-ntt convolve) defines three functions
;; specialized to the constant MODULUS.  Their names default to NTT!,
;; INVERSE-NTT! and CONVOLVE, interned in the package current at expansion
;; time.
(defmacro define-ntt (modulus &key ntt inverse-ntt convolve &environment env)
  (assert (constantp modulus env))
  (let* ((modulus #+sbcl (sb-int:constant-form-value modulus env)
                  #-sbcl modulus)
         (ntt (or ntt (intern "NTT!")))
         (inverse-ntt (or inverse-ntt (intern "INVERSE-NTT!")))
         (convolve (or convolve (intern "CONVOLVE")))
         (ntt-base (gensym "*NTT-BASE*"))
         (ntt-inv-base (gensym "*NTT-INV-BASE*"))
         ;; Number of trailing zero bits of MODULUS - 1.
         (base-size (%tzcount (- modulus 1)))
         (root (%calc-generator modulus))
         ;; NOTE(review): MODULUS was already evaluated by the first binding;
         ;; this rebinding looks redundant -- confirm before removing.
         (modulus (sb-int:constant-form-value modulus env)))
(declare (ntt-int modulus))
    `(progn
       ;; Twiddle tables: entry i of the forward table is the negated
       ;; ROOT^((MODULUS-1) >> (i+2)); the inverse table holds the modular
       ;; inverses of those entries.
       (declaim (ntt-vector ,ntt-base ,ntt-inv-base))
       (sb-ext:define-load-time-global ,ntt-base
           (make-array ,base-size :element-type 'ntt-int))
       (sb-ext:define-load-time-global ,ntt-inv-base
           (make-array ,base-size :element-type 'ntt-int))
       (dotimes (i ,base-size)
         (setf (aref ,ntt-base i)
               (mod (- (%mod-power ,root (ash (- ,modulus 1) (- (+ i 2))) ,modulus))
                    ,modulus)
               (aref ,ntt-inv-base i)
               (%mod-inverse (aref ,ntt-base i) ,modulus)))

       ;; Forward transform.  VECTOR is coerced to NTT-VECTOR and transformed
       ;; in place; its length is checked by CHECK-NTT-VECTOR.
       (declaim (ftype (function * (values ntt-vector &optional)) ,ntt))
       (defun ,ntt (vector)
         (declare (optimize (speed 3) (safety 0))
                  (vector vector))
         (check-ntt-vector vector)
         (labels ((mod* (x y) (fast-mod (* x y) ,modulus))
                  (mod+ (x y) (%himod (+ x y) ,modulus))
                  (mod- (x y) (%lomod (- x y) ,modulus)))
           (declare (inline mod* mod+ mod-))
           (let* ((vector (coerce vector 'ntt-vector))
                  (len (length vector))
                  (base ,ntt-base))
             (declare (ntt-vector vector base)
                      (ntt-int len))
             (when (<= len 1)
               (return-from ,ntt vector))
             ;; M is the half-width of a butterfly block: len/2, len/4, ..., 1.
             ;; W and K restart at 1 and 0 for every M.
             (loop for m of-type ntt-int = (ash len -1) then (ash m -1)
                   while (> m 0)
                   for w of-type ntt-int = 1
                   for k of-type ntt-int = 0
                   do (loop for s of-type ntt-int from 0 below len by (* 2 m)
                            do (loop for i from s below (+ s m)
                                     for j from (+ s m)
                                     for x = (aref vector i)
                                     for y = (mod* (aref vector j) w)
                                     do (setf (aref vector i) (mod+ x y)
                                              (aref vector j) (mod- x y)))
                               ;; Advance the twiddle factor for the next block.
                               (incf k)
                               (setq w (mod* w (aref base (%tzcount k))))))
             vector)))

       ;; Inverse transform.  The result is multiplied by 1/len only when
       ;; INVERSE is true.
       (declaim (ftype (function * (values ntt-vector &optional)) ,inverse-ntt))
       (defun ,inverse-ntt (vector &optional inverse)
         (declare (optimize (speed 3) (safety 0))
                  (vector vector))
         (check-ntt-vector vector)
         (labels ((mod* (x y) (fast-mod (* x y) ,modulus))
                  (mod+ (x y) (%himod (+ x y) ,modulus))
                  (mod- (x y) (%lomod (- x y) ,modulus)))
           (declare (inline mod* mod+ mod-))
           (let* ((vector (coerce vector 'ntt-vector))
                  (len (length vector))
                  (base ,ntt-inv-base))
             (declare (ntt-vector vector base)
                      (ntt-int len))
             (when (<= len 1)
               (return-from ,inverse-ntt vector))
             ;; Block half-width M grows 1, 2, 4, ... while below LEN.
             (loop for m of-type ntt-int = 1 then (ash m 1)
                   while (< m len)
                   for w of-type ntt-int = 1
                   for k of-type ntt-int = 0
                   do (loop for s of-type ntt-int from 0 below len by (* 2 m)
                            do (loop for i from s below (+ s m)
                                     for j from (+ s m)
                                     for x = (aref vector i)
                                     for y = (aref vector j)
                                     do (setf (aref vector i) (mod+ x y)
                                              (aref vector j) (mod* (mod- x y) w)))
                               (incf k)
                               (setq w (mod* w (aref base (%tzcount k))))))
             (when inverse
               (let ((inv-len (mod-inverse len ,modulus)))
                 (dotimes (i len)
                   (setf (aref vector i) (mod* inv-len (aref vector i))))))
             vector)))

       ;; Product of two coefficient vectors; the result has length
       ;; len1 + len2 - 1, or 0 when either input is empty.
       (declaim (ftype (function * (values ntt-vector &optional)) ,convolve))
       (defun ,convolve (vector1 vector2)
         (declare (optimize (speed 3))
                  (vector vector1 vector2))
         ;; TODO: if (EQ VECTOR1 VECTOR2) holds, the number of FFTs can be
         ;; reduced.
         (let* ((len1 (length vector1))
                (len2 (length vector2))
                (mul-len (max 0 (- (+ len1 len2) 1)))
                (vector1 (coerce vector1 'ntt-vector))
                (vector2 (coerce vector2 'ntt-vector)))
           (declare (ntt-vector vector1 vector2)
                    ((mod #.array-dimension-limit) mul-len))
           (when (or (zerop len1) (zerop len2))
             (return-from ,convolve (make-array 0 :element-type 'ntt-int)))
           ;; naive convolution
           (when (<= (min len1 len2) 32)
             (let ((res (make-array mul-len :element-type 'ntt-int :initial-element 0)))
               (declare (optimize (speed 3) (safety 0)))
               (dotimes (d mul-len)
                 ;; 0 <= i <= deg1, 0 <= j <= deg2
                 (loop with coef of-type ntt-int = 0
                       for i from (max 0 (- d (- len2 1))) to (min d (- len1 1))
                       for j = (- d i)
                       do (setq coef (fast-mod (+ coef (* (aref vector1 i) (aref vector2 j)))
                                               ,modulus))
                       finally (setf (aref res d) coef)))
               (return-from ,convolve res)))
           (let* (;; power of two ceiling
                  (required-len (ash 1 (integer-length (max 0 (- mul-len 1)))))
                  (vector1 (,ntt (%adjust-array vector1 required-len)))
                  (vector2 (,ntt (%adjust-array vector2 required-len))))
             ;; Pointwise product in the transformed domain, then transform
             ;; back and cut to the true product length.
             (dotimes (i required-len)
               (setf (aref vector1 i)
                     (fast-mod (* (aref vector1 i) (aref vector2 i)) ,modulus)))
             (subseq (,inverse-ntt vector1 t) 0 mul-len)))))))

#+(or)
(define-ntt +mod+)

(defpackage :cp/fps
  (:use :cl
:cp/ntt :cp/mod-inverse :cp/mod-power :cp/mod-sqrt :cp/static-mod)
  (:export #:poly-prod #:poly-inverse #:poly-floor #:poly-mod #:poly-sub #:poly-add
           #:multipoint-eval #:poly-total-prod #:chirp-z #:bostan-mori
           #:poly-differentiate! #:poly-integrate #:poly-log #:poly-exp
           #:poly-power #:poly-sqrt))
(in-package :cp/fps)

;; TODO: integrate with cp/polynomial

;; Defines NTT!, INVERSE-NTT! and POLY-PROD for +MOD+ in this package.
(define-ntt +mod+
  :convolve poly-prod)

;; Returns VECTOR itself when SIZE is NIL or already equals its length;
;; otherwise a fresh zero-filled vector of length SIZE with the leading
;; elements of VECTOR copied in.
(declaim (inline %adjust))
(defun %adjust (vector size)
  (declare (ntt-vector vector))
  (if (or (null size) (= size (length vector)))
      vector
      (let ((res (make-array size :element-type 'ntt-int :initial-element 0)))
        (replace res vector)
        res)))

;; Computes 2^(INTEGER-LENGTH (X - 1)).
(declaim (inline %power-of-two-ceiling))
(defun %power-of-two-ceiling (x)
  (declare (ntt-int x))
  (ash 1 (integer-length (- x 1))))

;; (declaim (ftype (function * (values ntt-vector &optional)) poly-inverse))
;; (defun poly-inverse (poly &optional result-length)
;;   (declare (optimize (speed 3))
;;            (vector poly)
;;            ((or null fixnum) result-length))
;;   (let* ((poly (coerce poly 'ntt-vector))
;;          (n (length poly)))
;;     (declare (ntt-vector poly))
;;     (when (or (zerop n)
;;               (zerop (aref poly 0)))
;;       (error 'division-by-zero
;;              :operation #'poly-inverse
;;              :operands (list poly)))
;;     (let ((res (make-array 1
;;                            :element-type 'ntt-int
;;                            :initial-element (mod-inverse (aref poly 0) +mod+)))
;;           (result-length (or result-length n)))
;;       (declare (ntt-vector res))
;;       (loop for i of-type ntt-int = 1 then (ash i 1)
;;             while (< i result-length)
;;             for decr = (poly-prod (poly-prod res res)
;;                                   (subseq poly 0 (min (length poly) (* 2 i))))
;;             for decr-len = (length decr)
;;             do (setq res (%adjust res (* 2 i) :initial-element 0))
;;                (dotimes (j (* 2 i))
;;                  (setf (aref res j)
;;                        (mod (the ntt-int
;;                                  (+ (mod (* 2 (aref res j)) +mod+)
;;                                     (if (>= j decr-len) 0 (- +mod+ (aref decr j)))))
;;                             +mod+))))
;;       (%adjust res result-length))))

;; Reference: https://opt-cp.com/fps-fast-algorithms/
(declaim (ftype (function * (values ntt-vector &optional)) poly-inverse))
(defun
poly-inverse (poly &optional result-length)
  ;; Returns the first RESULT-LENGTH coefficients (default: the length of
  ;; POLY) of the power-series inverse of POLY.  Signals DIVISION-BY-ZERO when
  ;; POLY is empty or its constant term is zero.
  (declare (optimize (speed 3))
           (vector poly)
           ((or null fixnum) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (n (length poly)))
    (declare (ntt-vector poly))
    (when (or (zerop n)
              (zerop (aref poly 0)))
      (error 'division-by-zero
             :operation #'poly-inverse
             :operands (list poly)))
    (let* ((result-length (or result-length n))
           (res (make-array 1
                            :element-type 'ntt-int
                            :initial-element (mod-inverse (aref poly 0) +mod+))))
      (declare (ntt-vector res))
      ;; Each round doubles the number of known coefficients from I to 2I.
      (loop for i of-type ntt-int = 1 then (ash i 1)
            while (< i result-length)
            for f of-type ntt-vector
               = (make-array (* 2 i) :element-type 'ntt-int :initial-element 0)
            for g of-type ntt-vector
               = (make-array (* 2 i) :element-type 'ntt-int :initial-element 0)
            do (replace f poly :end2 (min n (* 2 i)))
               (replace g res)
               (ntt! f)
               (ntt! g)
               (dotimes (j (* 2 i))
                 (setf (aref f j) (mod (* (aref g j) (aref f j)) +mod+)))
               (inverse-ntt! f)
               ;; Move the upper half down and clear the upper half.
               (replace f f :start1 0 :end1 i :start2 i :end2 (* 2 i))
               (fill f 0 :start i :end (* 2 i))
               (ntt! f)
               (dotimes (j (* 2 i))
                 (setf (aref f j) (mod (* (aref g j) (aref f j)) +mod+)))
               (inverse-ntt! f)
               ;; Neither INVERSE-NTT! call above applied the 1/len factor, so
               ;; SCALE = -(1/(2I))^2 applies both factors and the sign at
               ;; once.
               (let* ((inv-len (mod-inverse (* 2 i) +mod+))
                      (scale (mod (* inv-len (- +mod+ inv-len)) +mod+)))
                 (dotimes (j i)
                   (setf (aref f j) (mod (* scale (aref f j)) +mod+)))
                 (setq res (%adjust res (* 2 i)))
                 (replace res f :start1 i)))
      (%adjust res result-length))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-floor))
(defun poly-floor (poly1 poly2)
  ;; Quotient of POLY1 by POLY2, computed on the reversed coefficient vectors
  ;; with POLY-INVERSE.  Returns an empty vector when POLY2 has more
  ;; significant coefficients than POLY1.
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((poly1 (coerce poly1 'ntt-vector))
         (poly2 (coerce poly2 'ntt-vector))
         ;; Lengths with trailing zero coefficients ignored.
         (len1 (+ 1 (or (position 0 poly1 :from-end t :test-not #'eql) -1)))
         (len2 (+ 1 (or (position 0 poly2 :from-end t :test-not #'eql) -1))))
    (when (> len2 len1)
      (return-from poly-floor (make-array 0 :element-type 'ntt-int)))
    (setq poly1 (nreverse (subseq poly1 0 len1))
          poly2 (nreverse (subseq poly2 0 len2)))
    (let* ((quot-len (+ 1 (- len1 len2)))
           (quot (%adjust (poly-prod poly1 (poly-inverse poly2 quot-len))
                          quot-len)))
      (nreverse quot))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-sub))
(defun poly-sub (poly1 poly2)
  ;; Coefficient-wise POLY1 - POLY2 modulo +MOD+, with trailing zero
  ;; coefficients removed.
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((poly1 (coerce poly1 'ntt-vector))
         (poly2 (coerce poly2 'ntt-vector))
         (len (max (length poly1) (length poly2)))
         (res (make-array len :element-type 'ntt-int :initial-element 0)))
    (replace res poly1)
    (dotimes (i (length poly2))
      (let ((diff (- (aref res i) (aref poly2 i))))
        (setf (aref res i)
              (if (< diff 0)
                  (+ diff +mod+)
                  diff))))
    (let ((end (+ 1 (or (position 0 res :from-end t :test-not #'eql) -1))))
      (%adjust res end))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-add))
(defun poly-add (poly1 poly2)
  ;; Coefficient-wise POLY1 + POLY2 modulo +MOD+, with trailing zero
  ;; coefficients removed.
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((poly1 (coerce poly1 'ntt-vector))
         (poly2 (coerce poly2 'ntt-vector))
         (len (max (length poly1) (length poly2)))
         (res (make-array len :element-type 'ntt-int :initial-element 0)))
    (replace res poly1)
    (dotimes (i (length poly2))
      (setf (aref res i)
            (mod (+ (aref res i) (aref poly2 i)) +mod+)))
    (let ((end (+ 1 (or (position 0
res :from-end t :test-not #'eql) -1))))
      (%adjust res end))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-mod))
(defun poly-mod (poly1 poly2)
  ;; Remainder POLY1 - floor(POLY1 / POLY2) * POLY2 with trailing zero
  ;; coefficients removed; an all-zero POLY1 yields an empty vector
  ;; immediately.
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let ((poly1 (coerce poly1 'ntt-vector))
        (poly2 (coerce poly2 'ntt-vector)))
    (when (loop for coef across poly1 always (zerop coef))
      (return-from poly-mod (make-array 0 :element-type 'ntt-int)))
    (let* ((quotient (poly-floor poly1 poly2))
           (res (poly-sub poly1 (poly-prod quotient poly2)))
           (end (+ 1 (or (position 0 res :from-end t :test-not #'eql) -1))))
      (subseq res 0 end))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-total-prod))
(defun poly-total-prod (polys)
  "Returns the total polynomial product: polys[0] * polys[1] * ... * polys[n-1]."
  (declare (vector polys))
  (let* ((n (length polys))
         (dp (make-array n :element-type t)))
    (declare ((mod #.array-dimension-limit) n))
    ;; The empty product is the constant polynomial 1.
    (when (zerop n)
      (return-from poly-total-prod
        (make-array 1 :element-type 'ntt-int :initial-element 1)))
    (replace dp polys)
    ;; Pairwise merging: at each level DP[i] absorbs DP[i + stride].
    (loop for stride of-type (mod #.array-dimension-limit) = 1 then (ash stride 1)
          while (< stride n)
          do (loop for i of-type (mod #.array-dimension-limit) from 0 by (* stride 2)
                   while (< (+ i stride) n)
                   do (setf (aref dp i)
                            (poly-prod (aref dp i) (aref dp (+ i stride))))))
    (coerce (the vector (aref dp 0)) 'ntt-vector)))

(declaim (ftype (function * (values ntt-vector &optional)) multipoint-eval))
(defun multipoint-eval (poly points)
  "The length of POINTS must be a power of two."
(declare (optimize (speed 3))
         (vector poly points)
         #+sbcl (sb-ext:muffle-conditions style-warning))
  (check-ntt-vector points)
  (let* ((poly (coerce poly 'ntt-vector))
         (points (coerce points 'ntt-vector))
         (len (length points))
         ;; Product tree stored heap-style: the children of node POS are
         ;; 2*POS+1 and 2*POS+2.
         (table (make-array (max 0 (- (* 2 len) 1)) :element-type 'ntt-vector))
         (res (make-array len :element-type 'ntt-int)))
    (unless (zerop len)
      ;; Build phase: a leaf holds the linear factor (x - points[l]); an inner
      ;; node holds the product of its two children.
      (labels ((%build (l r pos)
                 (declare ((mod #.array-dimension-limit) l r pos))
                 (if (= (- r l) 1)
                     (let ((lin (make-array 2 :element-type 'ntt-int)))
                       (setf (aref lin 0) (- +mod+ (aref points l)) ;; NOTE: non-zero
                             (aref lin 1) 1)
                       (setf (aref table pos) lin))
                     (let ((mid (ash (+ l r) -1)))
                       (%build l mid (+ 1 (* pos 2)))
                       (%build mid r (+ 2 (* pos 2)))
                       (setf (aref table pos)
                             (poly-prod (aref table (+ 1 (* pos 2)))
                                        (aref table (+ 2 (* pos 2)))))))))
        (%build 0 len 0))
      ;; Evaluation phase: reduce POLY modulo each child on the way down; at a
      ;; leaf the remainder's constant term (0 when the remainder is empty) is
      ;; the value at that point.
      (labels ((%eval (poly l r pos)
                 (declare ((mod #.array-dimension-limit) l r pos))
                 (if (= (- r l) 1)
                     (let ((tmp (poly-mod poly (aref table pos))))
                       (setf (aref res l) (if (zerop (length tmp)) 0 (aref tmp 0))))
                     (let ((mid (ash (+ l r) -1)))
                       (%eval (poly-mod poly (aref table (+ (* 2 pos) 1)))
                              l mid (+ (* 2 pos) 1))
                       (%eval (poly-mod poly (aref table (+ (* 2 pos) 2)))
                              mid r (+ (* 2 pos) 2))))))
        (%eval poly 0 len 0)))
    res))

;; not tested
(declaim (ftype (function * (values ntt-vector &optional)) chirp-z))
(defun chirp-z (poly base length)
  "Does multipoint evaluation of POLY with powers of BASE: P(base^0), P(base^1), ..., P(base^(length-1)). BASE must be coprime with modulus. Time complexity is O((N+MOD)log(N+MOD)).
Reference: https://codeforces.com/blog/entry/83532"
  (declare (optimize (speed 3))
           (vector poly)
           (ntt-int base length))
  ;; An empty POLY gives a zero-filled result of the requested length.
  (when (zerop (length poly))
    (return-from chirp-z
      (make-array length :element-type 'ntt-int :initial-element 0)))
  (let* ((poly (coerce poly 'ntt-vector))
         (binv (mod-inverse base +mod+))
         (n (length poly))
         (m (max length n))
         (n+m (+ n m))
         (cs (make-array n :element-type 'ntt-int :initial-element 0))
         (ds (make-array n+m :element-type 'ntt-int :initial-element 0)))
    (declare (ntt-int n m n+m))
    ;; cs[i] = poly[j] * binv^(j(j-1)/2) with j = n-1-i, i.e. the scaled
    ;; coefficients stored in reverse order.
    (dotimes (i n)
      (setf (aref cs i)
            (mod (* (aref poly (- n 1 i))
                    (mod-power binv (ash (* (- n 1 i) (- n 2 i)) -1) +mod+))
                 +mod+)))
    ;; ds[i] = base^(i(i-1)/2)
    (dotimes (i n+m)
      (setf (aref ds i)
            (mod-power base (ash (* i (- i 1)) -1) +mod+)))
    ;; Entries n-1 .. n-2+length of the product CS * DS are taken, then entry
    ;; I is multiplied by binv^(i(i-1)/2).
    (let ((result (subseq (poly-prod cs ds) (- n 1) (+ (- n 1) length))))
      (dotimes (i length)
        (setf (aref result i)
              (mod (* (aref result i)
                      (mod-power binv (ash (* i (- i 1)) -1) +mod+))
                   +mod+)))
      result)))

(declaim (ftype (function * (values ntt-int &optional)) bostan-mori))
(defun bostan-mori (index num denom)
  "Returns [x^index](num(x)/denom(x)). 
Reference:
https://arxiv.org/abs/2008.08822
https://qiita.com/ryuhe1/items/da5acbcce4ac1911f47 (Japanese)"
  (declare (optimize (speed 3))
           (unsigned-byte index)
           (vector num denom))
  (labels (;; Coefficients of P at even indices: P[0], P[2], ...
           (even (p)
             (let ((res (make-array (ceiling (length p) 2) :element-type 'ntt-int)))
               (dotimes (i (length res))
                 (setf (aref res i) (aref p (* 2 i))))
               res))
           ;; Coefficients of P at odd indices: P[1], P[3], ...
           (odd (p)
             (let ((res (make-array (floor (length p) 2) :element-type 'ntt-int)))
               (dotimes (i (length res))
                 (setf (aref res i) (aref p (+ 1 (* 2 i)))))
               res))
           ;; Copy of P with the odd-index coefficients negated modulo +MOD+,
           ;; i.e. P(-x).
           (negate (p)
             (let ((res (copy-seq p)))
               (loop for i from 1 below (length res) by 2
                     do (setf (aref res i)
                              (if (zerop (aref res i))
                                  0
                                  (- +mod+ (aref res i)))))
               res)))
    (let ((num (coerce num 'ntt-vector))
          (denom (coerce denom 'ntt-vector)))
      ;; The constant term of DENOM must be non-zero.
      (when (or (zerop (length denom))
                (zerop (aref denom 0)))
        (error 'division-by-zero
               :operands (list num denom)
               :operation 'bostan-mori))
      ;; Each round multiplies NUM and DENOM by DENOM(-x), keeps the even or
      ;; odd half of the new numerator according to the parity of INDEX, keeps
      ;; the even half of the new denominator, and halves INDEX.
      (loop while (> index 0)
            for denom- = (negate denom)
            for u = (poly-prod num denom-)
            when (evenp index)
            do (setq num (even u))
            else
            do (setq num (odd u))
            do (setq denom (even (poly-prod denom denom-))
                     index (ash index -1))
            finally (return (rem (* (if (zerop (length num)) 0 (aref num 0))
                                    (mod-inverse (aref denom 0) +mod+))
                                 +mod+))))))

(declaim (inline poly-differentiate!))
(defun poly-differentiate! (p)
  "Returns the derivative of P."
  (declare (vector p))
  ;; NOTE: COERCE may return P itself when it already is an NTT-VECTOR, in
  ;; which case the loop below overwrites the caller's vector.
  (let ((p (coerce p 'ntt-vector)))
    (when (zerop (length p))
      (return-from poly-differentiate! p))
    ;; p[i] <- (i+1) * p[i+1]; the last slot is left as is and excluded below.
    (dotimes (i (- (length p) 1))
      (declare (ntt-int i))
      (setf (aref p i) (mod (* (aref p (+ i 1)) (+ i 1)) +mod+)))
    ;; Drop trailing zeros; the result is a fresh vector made by SUBSEQ.
    (let ((end (+ 1 (or (position 0 p :from-end t :end (- (length p) 1) :test-not #'eql)
                        -1))))
      (subseq p 0 end))))

;; Table of modular inverses, indexed by the integer; grown on demand.
(declaim (ntt-vector *inv*))
(defparameter *inv*
  (make-array 2 :element-type 'ntt-int :initial-contents '(0 1)))

(defun fill-inv! 
(new-size)
  (declare (optimize (speed 3))
           ((mod #.array-dimension-limit) new-size))
  ;; Grows *INV* to the power-of-two ceiling of NEW-SIZE (never shrinks) and
  ;; fills the new slots with inv[x] = mod - (inv[mod rem x] * floor(mod / x)) mod mod.
  (let* ((old-size (length *inv*))
         (new-size (%power-of-two-ceiling (max old-size new-size))))
    (when (< old-size new-size)
      (loop with inv of-type ntt-vector = (%adjust *inv* new-size)
            for x from old-size below new-size
            do (setf (aref inv x)
                     (- +mod+
                        (mod (* (aref inv (rem +mod+ x))
                                (floor +mod+ x))
                             +mod+)))
            finally (setq *inv* inv)))))

(declaim (inline poly-integrate))
(defun poly-integrate (p)
  "Returns an indefinite integral of P. Assumes the integration constant to be
zero."
  (declare (vector p))
  (let* ((p (coerce p 'ntt-vector))
         (n (length p)))
    ;; An empty input yields an empty result.
    (when (zerop n)
      (return-from poly-integrate (make-array 0 :element-type 'ntt-int)))
    ;; Make sure the inverse table covers the indices 1..n used below.
    (fill-inv! (+ n 1))
    (let ((result (make-array (+ n 1) :element-type 'ntt-int :initial-element 0))
          (inv *inv*))
      ;; result[i+1] = p[i] * inv[i+1]; result[0] stays 0.
      (dotimes (i n)
        (setf (aref result (+ i 1))
              (mod (* (aref p i) (aref inv (+ i 1))) +mod+)))
      result)))

;; POLY-LOG integrates (derivative of POLY) * (POLY-INVERSE of POLY); the
;; constant term of POLY must be 1 (asserted).
(declaim (ftype (function * (values ntt-vector &optional)) poly-log))
(defun poly-log (poly &optional result-length)
  (declare (optimize (speed 3))
           (vector poly)
           ((or null (integer 1 (#.array-dimension-limit))) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (result-length (or result-length (length poly))))
    (assert (= 1 (aref poly 0)))
    (let ((res (poly-integrate
                (%adjust (poly-prod (poly-differentiate! 
(copy-seq poly)) (poly-inverse poly result-length)) (- result-length 1)))))
      (%adjust res result-length))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-exp))
(defun poly-exp (poly &optional result-length)
  "Returns the first RESULT-LENGTH coefficients of exp(POLY). RESULT-LENGTH
defaults to the length of POLY and may be larger or smaller than it.

POLY must be empty or have a zero constant term; otherwise the assertion
fails."
  (declare (optimize (speed 3))
           (vector poly)
           ((or null (mod #.array-dimension-limit)) result-length))
  (assert (or (zerop (length poly)) (zerop (aref poly 0))))
  (let* ((poly (coerce poly 'ntt-vector))
         (poly-len (length poly))
         (result-length (or result-length poly-len))
         (res (make-array 1 :element-type 'ntt-int :initial-element 1)))
    (declare ((mod #.array-dimension-limit) poly-len))
    ;; Newton iteration: res <- res * (1 + POLY - log(res)), doubling the
    ;; length of RES each round.
    (loop until (>= (length res) result-length)
          for new-len of-type (mod #.array-dimension-limit) = (* 2 (length res))
          for log = (poly-log res new-len)
          do ;; Every one of the NEW-LEN entries must become POLY[i] - log[i].
             ;; Beyond the end of POLY the coefficient is 0, so the entry is
             ;; -log[i]; leaving it untouched would add log[i] instead of
             ;; subtracting it and corrupt the result whenever RESULT-LENGTH
             ;; exceeds the length of POLY.
             (dotimes (i new-len)
               (setf (aref log i)
                     (mod (- (if (< i poly-len) (aref poly i) 0)
                             (aref log i))
                          +mod+)))
             (setf (aref log 0) (mod (+ 1 (aref log 0)) +mod+)
                   res (%adjust (poly-prod res log) new-len)))
    (%adjust res result-length)))

;; POLY-POWER factors out the lowest non-zero term c*x^INIT-POS, computes
;; exp(EXP * log(POLY / (c*x^INIT-POS))), then multiplies by c^EXP and shifts
;; the result by INIT-POS * EXP positions.
(declaim (ftype (function * (values ntt-vector &optional)) poly-power))
(defun poly-power (poly exp &optional result-length)
  (declare (optimize (speed 3))
           (vector poly)
           ((integer 0 #.most-positive-fixnum) exp)
           ((or null (mod #.array-dimension-limit)) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (result-length (or result-length (length poly)))
         (init-pos (position 0 poly :test-not #'eql)))
    (when (or (null init-pos)
              (zerop result-length)
              (> (* init-pos exp) result-length))
      (return-from poly-power
        (make-array result-length :element-type 'ntt-int :initial-element 0)))
    (let ((tmp (subseq poly init-pos)))
      (let ((inv (mod-inverse (aref poly init-pos) +mod+)))
        (dotimes (i (length tmp))
          (setf (aref tmp i) (mod (* (aref tmp i) inv) +mod+))))
      (setq tmp (poly-log tmp result-length))
      (let ((exp (mod exp +mod+)))
        (dotimes (i (length tmp))
          (setf (aref tmp i) (mod (* (aref tmp i) exp) +mod+)))
        (setq tmp (poly-exp tmp result-length)))
      (let ((power (mod-power (aref poly init-pos) exp +mod+)))
        (dotimes (i (length tmp))
          (setf (aref tmp i) (mod (* (aref tmp i) 
power) +mod+))))
      ;; Shift the result by INIT-POS * EXP positions when the input had
      ;; leading zero coefficients.
      (if (zerop init-pos)
          tmp
          (let ((res (make-array result-length :element-type 'ntt-int :initial-element 0)))
            (replace res tmp :start1 (min (length res) (* init-pos exp)))
            res)))))

;; (+MOD+ + 1) / 2, i.e. the inverse of 2 modulo +MOD+ when +MOD+ is odd.
(defconstant +inv2+ (ash (+ +mod+ 1) -1))
;; The negation of +INV2+ modulo +MOD+.
(defconstant +inv-2+ (mod (- +inv2+) +mod+))

(declaim (ftype (function * (values (or null ntt-vector) &optional))
                poly-sqrt %poly-sqrt))

;; Earlier implementation kept for reference (plain Newton iteration built on
;; POLY-INVERSE and POLY-PROD):
;; (defun %poly-sqrt (poly result-length)
;;   (declare (optimize (speed 3))
;;            (ntt-vector poly)
;;            ((mod #.array-dimension-limit) result-length))
;;   (let ((sqrt (mod-sqrt (aref poly 0) +mod+)))
;;     (when sqrt
;;       (assert (= (aref poly 0) (mod (* sqrt sqrt) +mod+)))
;;       (let ((res (make-array 1 :element-type 'ntt-int :initial-element sqrt)))
;;         (loop for len of-type (mod #.array-dimension-limit) = 1 then (ash len 1)
;;               while (< len result-length)
;;               do (let* ((next-len (ash len 1))
;;                         (tmp (poly-prod (poly-inverse res next-len)
;;                                         (%adjust poly next-len))))
;;                    (declare ((mod #.array-dimension-limit) next-len))
;;                    (setq res (%adjust res next-len))
;;                    (dotimes (i next-len)
;;                      (setf (aref res i)
;;                            (mod (* +inv2+ (mod (+ (aref res i) (aref tmp i)) +mod+))
;;                                 +mod+)))))
;;         (%adjust res result-length)))))

;; Core square-root routine. Returns NIL when MOD-SQRT finds no square root of
;; POLY[0]; POLY[0] must be non-zero (asserted via the root being non-zero).
;; RES accumulates the square root and TT its reciprocal series; each round
;; fills the index range [LEN, 2*LEN) of both.
;; NOTE(review): DP appears to hold the NTT image of the current prefix of RES
;; at the top of each round -- confirm against NTT!/INVERSE-NTT!.
(defun %poly-sqrt (poly result-length)
  (declare (optimize (speed 3))
           (ntt-vector poly)
           ((mod #.array-dimension-limit) result-length))
  (let ((sqrt (mod-sqrt (aref poly 0) +mod+)))
    (when sqrt
      (assert (and (not (zerop sqrt))
                   (= (aref poly 0) (mod (* sqrt sqrt) +mod+))))
      (let ((res (make-array result-length :element-type 'ntt-int :initial-element 0))
            (tt (make-array result-length :element-type 'ntt-int :initial-element 0))
            (dp (make-array 1 :element-type 'ntt-int :initial-element sqrt)))
        (setf (aref res 0) sqrt
              (aref tt 0) (mod-inverse sqrt +mod+))
        (loop for len of-type (mod #.array-dimension-limit) = 1 then (ash len 1)
              while (< len result-length)
              do ;; Square DP pointwise and transform back.
                 (dotimes (i len)
                   (setf (aref dp i) (mod (expt (aref dp i) 2) +mod+)))
                 (inverse-ntt! dp t)
                 (let* ((next-len (ash len 1))
                        (tmp (make-array next-len :element-type 'ntt-int :initial-element 0))
                        (tmp2 (make-array next-len :element-type 'ntt-int :initial-element 0)))
                   (declare ((mod #.array-dimension-limit) next-len))
                   ;; Upper half of TMP = DP - POLY[0..LEN) - POLY[LEN..2*LEN).
                   (replace tmp dp :start1 len)
                   (dotimes (i (min (length poly) len))
                     (setf (aref tmp (+ i len))
                           (mod (- (aref tmp (+ i len)) (aref poly i)) +mod+)))
                   ;; A negative count (POLY shorter than LEN) runs zero times.
                   (dotimes (i (min (- (length poly) len) len))
                     (setf (aref tmp (+ i len))
                           (mod (- (aref tmp (+ i len)) (aref poly (+ i len))) +mod+)))
                   ;; Multiply TMP by the first LEN entries of TT via NTT.
                   (ntt! tmp)
                   (replace tmp2 tt :end1 len)
                   (ntt! tmp2)
                   (dotimes (i next-len)
                     (setf (aref tmp i) (mod (* (aref tmp i) (aref tmp2 i)) +mod+)))
                   (inverse-ntt! tmp t)
                   ;; New entries of RES: the upper half of TMP times +INV-2+.
                   (loop for i from len below (min next-len result-length)
                         do (setf (aref res i) (mod (* (aref tmp i) +inv-2+) +mod+)))
                   ;; RES is complete; the update of TT below is not needed.
                   (when (>= next-len result-length)
                     (return))
                   ;; Extend TT to NEXT-LEN entries using the new prefix of RES.
                   (setq dp (subseq res 0 next-len))
                   (ntt! dp)
                   (dotimes (i next-len)
                     (setf (aref tmp i) (mod (* (aref dp i) (aref tmp2 i)) +mod+)))
                   (inverse-ntt! tmp t)
                   (fill tmp 0 :end len)
                   (ntt! tmp)
                   (dotimes (i next-len)
                     (setf (aref tmp i) (mod (* (aref tmp i) (aref tmp2 i)) +mod+)))
                   (inverse-ntt! tmp t)
                   (loop for i from len below next-len
                         do (setf (aref tt i) (mod (- (aref tmp i)) +mod+)))))
        res))))

;; Returns RESULT-LENGTH coefficients (default: the length of POLY) of a
;; square root of POLY, or NIL when none is found: either the lowest non-zero
;; coefficient sits at an odd index or %POLY-SQRT returns NIL.
(defun poly-sqrt (poly &optional result-length)
  (declare (optimize (speed 3))
           (vector poly)
           ((or null (mod #.array-dimension-limit)) result-length))
  (let* ((result-length (or result-length (length poly)))
         (poly (coerce poly 'ntt-vector)))
    (labels ((return-zero ()
               (return-from poly-sqrt
                 (make-array result-length :element-type 'ntt-int :initial-element 0))))
      (when (or (zerop (length poly))
                (zerop result-length))
        (return-zero))
      (if (zerop (aref poly 0))
          ;; Strip the leading zeros x^I, take the root of the rest and shift
          ;; the result by I/2 positions.
          (let ((i (position 0 poly :test-not #'=)))
            (unless i
              (return-zero))
            (when (oddp i)
              (return-from poly-sqrt))
            (when (<= result-length (ash i -1))
              (return-zero))
            (let ((tmp-res (%poly-sqrt (subseq poly i) (- result-length (ash i -1)))))
              (unless tmp-res
                (return-from poly-sqrt))
              (let ((res (make-array result-length :element-type 'ntt-int :initial-element 0)))
                (replace res tmp-res :start1 (ash i -1))
                res)))
          (%poly-sqrt poly result-length)))))

(defpackage :cp/binom-mod-prime
  (:use :cl :cp/static-mod)
  (:export #:binom #:perm #:multinomial #:stirling2 #:catalan #:multichoose
           #:*fact* #:*fact-inv* #:*inv*)
  (:documentation
   "Provides tables of factorials, inverses, inverses of factorials etc. modulo prime. 
build: O(n)
query: O(1)
"))
(in-package :cp/binom-mod-prime)

;; TODO: non-global handling

(defconstant +binom-size+ 510000)

(declaim ((simple-array (unsigned-byte 31) (*)) *fact* *fact-inv* *inv*))
(sb-ext:define-load-time-global *fact*
  (make-array +binom-size+ :element-type '(unsigned-byte 31))
  "table of factorials")
(sb-ext:define-load-time-global *fact-inv*
  (make-array +binom-size+ :element-type '(unsigned-byte 31))
  "table of inverses of factorials")
(sb-ext:define-load-time-global *inv*
  (make-array +binom-size+ :element-type '(unsigned-byte 31))
  "table of inverses of non-negative integers")

(defun initialize-binom ()
  "Fills *FACT*, *INV* and *FACT-INV* for all indices below +BINOM-SIZE+."
  (declare (optimize (speed 3) (safety 0)))
  ;; Seed values for indices 0 and 1.
  (setf (aref *fact* 0) 1
        (aref *fact* 1) 1)
  (setf (aref *fact-inv* 0) 1
        (aref *fact-inv* 1) 1)
  (setf (aref *inv* 1) 1)
  (loop for i from 2 below +binom-size+
        do (let ((inverse (- +mod+
                             (mod (* (aref *inv* (rem +mod+ i))
                                     (floor +mod+ i))
                                  +mod+))))
             ;; i! = i * (i-1)!  and  1/i! = (1/i) * 1/(i-1)!
             (setf (aref *fact* i) (mod (* i (aref *fact* (- i 1))) +mod+)
                   (aref *inv* i) inverse
                   (aref *fact-inv* i) (mod (* inverse (aref *fact-inv* (- i 1)))
                                            +mod+)))))

(initialize-binom)

(declaim (inline binom))
(defun binom (n k)
  "Returns nCk, the number of k-combinations of n things without repetition."
  (cond ((< n k) 0)
        ((< n 0) 0)
        ((< k 0) 0)
        (t (let ((denominator-inverse (mod (* (aref *fact-inv* k)
                                              (aref *fact-inv* (- n k)))
                                           +mod+)))
             (mod (* (aref *fact* n) denominator-inverse) +mod+)))))

(declaim (inline perm))
(defun perm (n k)
  "Returns nPk, the number of k-permutations of n things without repetition."
  (cond ((< n k) 0)
        ((< n 0) 0)
        ((< k 0) 0)
        (t (mod (* (aref *fact* n) (aref *fact-inv* (- n k))) +mod+))))

(declaim (inline multichoose))
(defun multichoose (n k)
  "Returns the number of k-combinations of n things with repetition."
  (let ((total (+ n k -1)))
    (binom total k)))

(declaim (inline multinomial))
(defun multinomial (&rest ks)
  "Returns the multinomial coefficient K!/k_1!k_2!...k_n! for K = k_1 + k_2 +
... + k_n. K must be equal to or smaller than
MOST-POSITIVE-FIXNUM. (multinomial) returns 1."
(let ((sum 0)
        (result 1))
    (declare ((integer 0 #.most-positive-fixnum) result sum))
    ;; RESULT accumulates the product of 1/k_i!, SUM accumulates K.
    (dolist (k ks)
      (incf sum k)
      (setq result
            (mod (* result (aref *fact-inv* k)) +mod+)))
    (mod (* result (aref *fact* sum)) +mod+)))

;; Expands a MULTINOMIAL call with a fixed number of arguments into straight
;; table lookups. Each argument form is bound to a fresh variable first, so it
;; is evaluated exactly once and in left-to-right order, as a call to the
;; function would do. (Splicing the forms directly into both the product and
;; the sum would evaluate them twice.)
(define-compiler-macro multinomial (&rest args)
  (case (length args)
    (0 (mod 1 +mod+))
    ;; k!/k! = 1; the argument is still evaluated for its side effects.
    (1 `(progn ,(first args) ,(mod 1 +mod+)))
    (otherwise
     (let ((vars (loop repeat (length args) collect (gensym "K"))))
       `(let ,(mapcar #'list vars args)
          (mod (* ,(reduce (lambda (x y) `(mod (* ,x ,y) +mod+))
                           vars
                           :key (lambda (x) `(aref *fact-inv* ,x)))
                  (aref *fact* (+ ,@vars)))
               +mod+))))))

(declaim (inline stirling2))
(defun stirling2 (n k)
  "Returns the stirling number of the second kind S2(n, k). Time complexity is
O(klog(n))."
  (declare ((integer 0 #.most-positive-fixnum) n k))
  (labels (;; Square-and-multiply: BASE^EXP modulo +MOD+.
           (mod-power (base exp)
             (declare ((integer 0 #.most-positive-fixnum) base exp))
             (loop with res of-type (integer 0 #.most-positive-fixnum) = 1
                   while (> exp 0)
                   when (oddp exp)
                   do (setq res (mod (* res base) +mod+))
                   do (setq base (mod (* base base) +mod+)
                            exp (ash exp -1))
                   finally (return res))))
    ;; Alternating sum of C(k, i) * i^n with sign (-1)^(k-i), kept within
    ;; [0, +MOD+) after every step, then divided by k!.
    (loop with result of-type fixnum = 0
          for i from 0 to k
          for delta = (mod (* (binom k i) (mod-power i n)) +mod+)
          when (evenp (- k i))
          do (incf result delta)
             (when (>= result +mod+)
               (decf result +mod+))
          else
          do (decf result delta)
             (when (< result 0)
               (incf result +mod+))
          finally (return (mod (* result (aref *fact-inv* k)) +mod+)))))

(declaim (inline catalan))
(defun catalan (n)
  "Returns the N-th Catalan number."
(declare ((integer 0 #.most-positive-fixnum) n))
  ;; (2n)! / ((n+1)! * n!) modulo +MOD+, using the precomputed tables.
  (mod (* (aref *fact* (* 2 n))
          (mod (* (aref *fact-inv* (+ n 1))
                  (aref *fact-inv* n))
               +mod+))
       +mod+))

(defpackage :cp/unlabeled-counting
  (:use :cl :cp/binom-mod-prime :cp/fps :cp/ntt :cp/static-mod :cp/zeta-integer)
  (:export #:unlabeled-deck-to-hand #:unlabeled-hand-to-deck))
(in-package :cp/unlabeled-counting)

;; Smallest power of two that is >= X (1 for X = 0 or X = 1).
(declaim (inline %power-of-two-ceiling))
(defun %power-of-two-ceiling (x)
  (declare (ntt-int x))
  (ash 1 (integer-length (- x 1))))

(declaim (ftype (function * (values ntt-vector &optional)) unlabeled-deck-to-hand))
(defun unlabeled-deck-to-hand (deck)
  "Converts an unlabeled deck enumerator to the (1-variable) hand enumerator.

NOTE:
- hand[0] = 1.
- deck[0] is ignored and can be any number."
  (declare #.cl-user::*opt*
           (vector deck))
  (let* ((deck (coerce deck 'ntt-vector))
         (n (length deck))
         (dp (make-array n :element-type 'ntt-int :initial-element 0)))
    (declare (ntt-int n))
    ;; dp[i*j] += deck[i] * inv[j] for every i >= 1, j >= 1 with i*j < n;
    ;; the hand enumerator is then exp(dp).
    (loop for i from 1 below n
          for d = (aref deck i)
          do (loop for j from 1 to (floor (- n 1) i)
                   for pos of-type (mod #.array-dimension-limit) = (* i j)
                   do (setf (aref dp pos)
                            (mod (+ (* d (aref *inv* j)) (aref dp pos))
                                 +mod+))))
    (poly-exp dp)))

(declaim (ftype (function * (values ntt-vector &optional)) unlabeled-hand-to-deck))
(defun unlabeled-hand-to-deck (hand)
  "Converts an unlabeled hand enumerator to the deck enumerator.

NOTE:
- hand[0] must be 1.
- deck[0] = 0."
(declare (optimize (speed 3))
           (vector hand))
  (let* ((hand (coerce hand 'ntt-vector))
         (n (length hand))
         (dp (make-array n :element-type 'ntt-int :initial-element 0)))
    (declare (ntt-int n))
    (when (> n 0)
      (assert (= (mod 1 +mod+) (aref hand 0))))
    ;; dp[i] starts as i * hand[i].
    (dotimes (i n)
      (setf (aref dp i) (mod (* (aref hand i) i) +mod+)))
    ;; Divide and conquer over [L, R): once the left half [L, MID) is final,
    ;; subtract its product with a prefix of HAND from the right half
    ;; [MID, R), then recurse into the right half.
    (labels ((recur (l r)
               (declare ((mod #.array-dimension-limit) l r))
               (when (< (+ l 1) r)
                 (let ((mid (ash (+ l r) -1)))
                   (recur l mid)
                   (let* ((poly1 (subseq hand 0 (- r l)))
                          (poly2 (subseq dp l mid))
                          (prod (poly-prod poly1 poly2)))
                     (loop for i from mid below r
                           do (setf (aref dp i)
                                    (mod (- (aref dp i) (aref prod (- i l)))
                                         +mod+))))
                   (recur mid r)))))
      (recur 0 n))
    ;; NOTE(review): presumably undoes a sum over divisors, using modular
    ;; subtraction as the inverse operation -- confirm against
    ;; INVERSE-DIVISOR-TRANSFORM!.
    (inverse-divisor-transform! dp
                                (lambda (x y)
                                  (declare (ntt-int x y))
                                  (mod (- x y) +mod+))
                                nil)
    ;; Finally divide entry I by I (multiply by inv[I]) for I >= 1.
    (loop for i from 1 below n
          do (setf (aref dp i)
                   (mod (* (aref dp i) (aref *inv* i)) +mod+)))
    dp))

(defpackage :cp/integer-log
  (:use :cl)
  (:export #:log2-ceil #:log-ceil #:log10-floor #:decimal-length))
(in-package :cp/integer-log)

(declaim (inline log2-ceil))
(defun log2-ceil (x)
  "Rounds up log2(x). Special case: (log2-ceil 0) = 0"
  (declare ((real 0) x))
  ;; INTEGER-LENGTH of -1 is 0, which gives the documented value for X = 0.
  (integer-length (- (ceiling x) 1)))

(declaim (inline log-ceil))
(defun log-ceil (x base)
  "Rounds up log(x). Signals DIVISION-BY-ZERO if X is zero."
(declare (real x)
           ((integer 2) base))
  (when (zerop x)
    (error 'division-by-zero
           :operands (list 0 base)
           :operation 'log-ceil))
  (if (integerp x)
      (loop for i from 0
            for y = x then (ceiling y base)
            when (<= y 1)
            do (return i))
      (nth-value 0 (ceiling (log x base)))))

(defconstant +word-bits+ 62)
(deftype uint () '(unsigned-byte #.+word-bits+))

;; *LO*[b] = (number of decimal digits of 2^(b-1)) - 1, indexed by the
;; INTEGER-LENGTH b of the argument; entry 0 is computed from (ash 1 -1) = 0.
(declaim ((simple-array (unsigned-byte 8) (*)) *lo*))
(sb-ext:define-load-time-global *lo*
  (let ((table (make-array (+ 1 +word-bits+) :element-type '(unsigned-byte 8))))
    (loop for bits from 0 below (length table)
          for smallest = (ash 1 (- bits 1))
          do (setf (aref table bits)
                   (- (length (write-to-string smallest)) 1)))
    table))

;; *HI-POWER10*[i] = 10^(i+1) - 1, capped at the largest UINT.
(declaim ((simple-array uint (*)) *hi-power10*))
(sb-ext:define-load-time-global *hi-power10*
  (let* ((max-lo (reduce #'max *lo*))
         (table (make-array (+ max-lo 1) :element-type 'uint)))
    (loop with word-max = (- (ash 1 +word-bits+) 1)
          for i from 0 below (length table)
          do (setf (aref table i)
                   (min (- (expt 10 (+ i 1)) 1) word-max)))
    table))

(declaim (inline log10-floor))
(defun log10-floor (x)
  "Returns floor(log_{10}(x)). Special case: (log10-floor 0) == 0."
  (declare (uint x))
  ;; ESTIMATE is exact for the smallest value with the same bit length as X;
  ;; it is one too small when X exceeds the largest number with ESTIMATE + 1
  ;; decimal digits.
  (let* ((estimate (aref *lo* (integer-length x)))
         (threshold (aref *hi-power10* estimate)))
    (if (> x threshold)
        (+ estimate 1)
        estimate)))

(declaim (inline decimal-length))
(defun decimal-length (x)
  "Special case (decimal-length 0) == 1."
(declare (uint x))
  ;; Number of decimal digits of X.
  (+ 1 (log10-floor x)))

(defpackage :cp/decimal-sequence
  (:use :cl :cp/integer-log)
  (:export #:decimal-seq-to-str))
(in-package :cp/decimal-sequence)

;; Renders the non-negative integers in SEQ as one base string of decimal
;; numbers joined by SEPARATOR (a space by default). The exact length is
;; computed in a first pass and the digits are written in place in a second.
(declaim (inline decimal-seq-to-str))
(defun decimal-seq-to-str (seq &optional (separator #\ ))
  (let ((length 0))
    (declare ((mod #.array-dimension-limit) length))
    ;; Pass 1: every element takes its digits plus one separator.
    (sb-sequence:dosequence (x seq)
      (incf length (+ 1 (decimal-length x))))
    ;; delete the last separator
    (when (> length 0)
      (decf length))
    ;; The string is pre-filled with SEPARATOR, so only digits are written.
    (let ((res (make-string length :element-type 'base-char :initial-element separator))
          (end 0))
      (declare ((mod #.array-dimension-limit) end))
      ;; Pass 2: END is the start position of the current number.
      (sb-sequence:dosequence (x seq)
        (let ((x x)
              (width (decimal-length x)))
          (declare ((integer 0 #.most-positive-fixnum) x))
          (if (zerop x)
              (setf (aref res end) #\0)
              ;; Emit the digits from least to most significant.
              (loop for i from (- width 1) downto 0
                    do (multiple-value-bind (quot rem) (floor x 10)
                         (setf (aref res (+ end i)) (code-char (+ 48 rem))
                               x quot))))
          (incf end (+ width 1))))
      res)))

;; BEGIN_USE_PACKAGE
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/decimal-sequence :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/unlabeled-counting :cl-user))
(in-package :cl-user)

;;;
;;; Body
;;;

;; Builds a length-N series that is zero except at the indices
;; k = i(3i-1)/2 (for positive and negative i), where it holds (-1)^i modulo
;; +MOD+, and returns its CP/FPS:POLY-INVERSE. The loop bounds are derived
;; from TT = isqrt(24(N-1)+1).
;; NOTE(review): presumably this is Euler's pentagonal number series, whose
;; inverse lists the partition numbers -- confirm. N must be at least 1;
;; N = 0 would pass a negative number to ISQRT.
(defun make-partition (n)
  (declare (optimize (speed 3))
           (uint31 n))
  (let ((res (make-array n :element-type 'uint31 :initial-element 0))
        (tt (isqrt (+ (* 24 (- n 1)) 1))))
    (loop for i from (+ (floor (- tt) 6) 1) to (floor (+ 1 tt) 6)
          for k = (ash (* i (- (* 3 i) 1)) -1)
          do (setf (aref res k)
                   (mod (- 1 (* 2 (logand i 1))) +mod+)))
    (cp/fps:poly-inverse res)))

;; Reads N from standard input and prints the N+1 values returned by
;; MAKE-PARTITION on one line.
(defun main ()
  (declare #.*opt*)
  (let* ((n (read))
         (res (make-partition (+ n 1))))
    (declare (uint31 n))
    (write-line (decimal-seq-to-str res))))

;; Run immediately unless loaded in a SWANK (interactive) session.
#-swank
(main)

;;;
;;; Test
;;;

;; Everything below is local development scaffolding and is read only when
;; the SWANK feature is present.
#+swank
(progn
  (defparameter *lisp-file-pathname* (uiop:current-lisp-file-pathname))
  (setq *default-pathname-defaults* (uiop:pathname-directory-pathname *lisp-file-pathname*))
  (uiop:chdir *default-pathname-defaults*)
  (defparameter *dat-pathname*
    (uiop:merge-pathnames* "test.dat" *lisp-file-pathname*))
  (defparameter *problem-url* "PROBLEM_URL_TO_BE_REPLACED"))

;; Writes an (empty) input file for BENCH.
#+swank
(defun gen-dat ()
  (uiop:with-output-file (out *dat-pathname* :if-exists :supersede)
    (format out "")))

;; Times a run on the data file; output is discarded by default.
#+swank
(defun bench (&optional (out (make-broadcast-stream)))
  (time (run *dat-pathname* out)))

;; Fail the compilation when SBCL recorded any undefined-function/variable
;; warnings.
#+(and sbcl (not swank))
(eval-when (:compile-toplevel)
  (when sb-c::*undefined-warnings*
    (error "undefined warnings: ~{~A~^ ~}" sb-c::*undefined-warnings*)))