Submit Info #62887

Problem Lang User Status Time Memory
Multipoint Evaluation lisp sansaqua AC 1900 ms 83.89 MiB

ケース詳細
Name Status Time Memory
example_00 AC 14 ms 2.63 MiB
example_01 AC 14 ms 2.66 MiB
max_random_00 AC 1900 ms 83.88 MiB
max_random_01 AC 1895 ms 83.89 MiB
random_00 AC 555 ms 63.02 MiB
random_01 AC 473 ms 62.30 MiB
random_02 AC 1747 ms 83.40 MiB
zero_00 AC 14 ms 2.63 MiB

;; Script prelude.  Unless running under SWANK or already inside the child
;; process, re-run this very file in a child SBCL with a 256MB control stack
;; and exit with the child's exit code.
#-swank
(unless (member :child-sbcl *features*)
  (quit
   :unix-status
   (process-exit-code
    (run-program *runtime-pathname*
                 `("--control-stack-size" "256MB"
                   "--noinform" "--disable-ldb" "--lose-on-corruption" "--end-runtime-options"
                   "--eval" "(push :child-sbcl *features*)"
                   "--script" ,(namestring *load-pathname*))
                 :output t :error t :input t))))

(in-package :cl-user)

(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; Optimization policy spliced into hot functions via #.*opt*.
  (defparameter *opt*
    #+swank '(optimize (speed 3) (safety 2))
    #-swank '(optimize (speed 3) (safety 0) (debug 0)))
  #+swank (ql:quickload '(:cl-debug-print :fiveam :cp/util) :silent t)
  #+swank (use-package :cp/util :cl-user)
  ;; Outside SWANK, #>FORM simply reads as (values FORM).
  #-swank (set-dispatch-macro-character
           #\# #\> (lambda (s c p) (declare (ignore c p)) `(values ,(read s nil nil t))))
  #+sbcl (dolist (f '(:popcnt :sse4)) (pushnew f sb-c:*backend-subfeatures*))
  #+sbcl (setq *random-state* (seed-random-state (nth-value 1 (get-time-of-day)))))

#-swank
(eval-when (:compile-toplevel)
  (setq *break-on-signals* '(and warning (not style-warning))))

#+swank
(set-dispatch-macro-character #\# #\> #'cl-debug-print:debug-print-reader)

;; Defines UINTn as (unsigned-byte n) and INTn as (signed-byte n) for each
;; listed bit width.
(macrolet ((def (b)
             `(progn (deftype ,(intern (format nil "UINT~A" b)) () '(unsigned-byte ,b))
                     (deftype ,(intern (format nil "INT~A" b)) () '(signed-byte ,b))))
           (define-int-types (&rest bits)
             `(progn ,@(mapcar (lambda (b) `(def ,b)) bits))))
  (define-int-types 2 4 7 8 15 16 31 32 62 63 64))

(defconstant +mod+ 1000000007)

;; Debug printer: under SWANK prints "FORM => VALUE" to *ERROR-OUTPUT*;
;; otherwise the macro body is empty and the call expands to NIL.
(defmacro dbg (&rest forms)
  (declare (ignorable forms))
  #+swank
  (if (= (length forms) 1)
      `(format *error-output* "~A => ~A~%" ',(car forms) ,(car forms))
      `(format *error-output* "~A => ~A~%" ',forms `(,,@forms))))

(declaim (inline println))
(defun println (obj &optional (stream *standard-output*))
  "Prints OBJ with PRINC followed by a newline to STREAM and returns OBJ.
*READ-DEFAULT-FLOAT-FORMAT* is bound to DOUBLE-FLOAT while printing a
double-float."
  (let ((*read-default-float-format*
          (if (typep obj 'double-float) 'double-float *read-default-float-format*)))
    (prog1 (princ obj stream) (terpri stream))))

;; BEGIN_INSERTED_CONTENTS
(defpackage :cp/experimental/barrett
  (:use :cl)
  (:import-from #:sb-ext #:truly-the)
  (:import-from #:sb-c
                #:deftransform #:derive-type #:defoptimizer #:defknown
                #:movable #:foldable #:flushable #:commutative #:always-translatable
                #:lvar-type #:lvar-value #:give-up-ir1-transform
                #:integer-type-numeric-bounds #:define-vop)
  (:import-from #:sb-vm
                #:move #:inst #:eax-offset #:edx-offset
                #:any-reg #:control-stack #:unsigned-reg #:positive-fixnum
                #:fixnumize #:ea)
  (:import-from #:sb-kernel #:specifier-type)
  (:import-from #:sb-int #:explicit-check #:constant-arg)
  (:export #:fast-mod #:%himod #:%lomod)
  (:documentation "Provides Barrett reduction."))
(in-package :cp/experimental/barrett)

(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; Type deriver for *-HIGH62: when both upper bounds are known, the result
  ;; is bounded by (ash (* high1 high2) -62).
  (defun derive-* (x y)
    (let ((high1 (nth-value 1 (integer-type-numeric-bounds (lvar-type x))))
          (high2 (nth-value 1 (integer-type-numeric-bounds (lvar-type y)))))
      (specifier-type (if (and (integerp high1) (integerp high2))
                          `(integer 0 ,(ash (* high1 high2) -62))
                          `(integer 0)))))
  ;; Type deriver for the mod-like functions: the result lies in [0, HIGH)
  ;; when the upper bound HIGH of MODULUS is known.
  (defun derive-mod (modulus)
    (let ((high (nth-value 1 (integer-type-numeric-bounds (lvar-type modulus)))))
      (specifier-type (if (integerp high)
                          `(integer 0 (,high))
                          `(integer 0)))))
  ;; Calls SB-VM::GPR-TN-P when that symbol exists in the running SBCL;
  ;; otherwise always returns T.
  (defun gpr-tn-p (x)
    (declare (ignorable x))
    #.(if (find-symbol "GPR-TN-P" :sb-vm)
          `(funcall (intern "GPR-TN-P" :sb-vm) x)
          t)))

(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; *-high62
  (defknown *-high62 ((unsigned-byte 62) (unsigned-byte 62)) (unsigned-byte 62)
      (movable foldable flushable commutative sb-c::always-translatable)
    :overwrite-fndb-silently t)

  (defoptimizer (*-high62 derive-type) ((x y))
    (derive-* x y))

  ;; Both operands in fixnum representation: MUL leaves the high word of the
  ;; product in EDX, which is shifted left by one and returned.
  (define-vop (fast-*-high62/fixnum)
    (:translate *-high62)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target eax)
           (y :scs (any-reg control-stack)))
    (:arg-types positive-fixnum positive-fixnum)
    (:temporary (:sc any-reg :offset eax-offset
                 :from (:argument 0) :to :result)
                eax)
    (:temporary (:sc any-reg :offset edx-offset :target r
                 :from :eval :to :result)
                edx)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline *-high62")
    (:vop-var vop)
    (:save-p :compute-only)
    (:generator 6
      (move eax x)
      (inst mul eax y)
      (inst shl edx 1)
      (move r edx)))

  ;; Same as above with a compile-time constant second operand.
  ;; NOTE(review): Y is emitted as an inline qword constant while X is a
  ;; fixnum-tagged register, so this variant looks like it may shift by one
  ;; bit more than FAST-*-HIGH62/FIXNUM does -- confirm which template SBCL
  ;; actually selects and whether the constant needs FIXNUMIZE.
  (define-vop (fast-c-*-high62-/fixnum)
    (:translate *-high62)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target eax))
    (:info y)
    (:arg-types positive-fixnum (:constant (unsigned-byte 62)))
    (:temporary (:sc any-reg :offset eax-offset
                 :from (:argument 0) :to :result)
                eax)
    (:temporary (:sc any-reg :offset edx-offset :target r
                 :from :eval :to :result)
                edx)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline constant *-high62")
    (:vop-var vop)
    (:save-p :compute-only)
    (:generator 6
      (move eax x)
      (inst mul eax (sb-c:register-inline-constant :qword y))
      (inst shl edx 1)
      (move r edx)))

  ;; Out-of-line entry point; the body is compiled through the VOPs above.
  (defun *-high62 (x y)
    (declare (explicit-check))
    (*-high62 x y)))

;; %himod
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown %himod ((unsigned-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
      (movable foldable flushable sb-c:always-translatable)
    :overwrite-fndb-silently t)

  (defoptimizer (%himod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))

  ;; Branch-free conditional subtraction: R := X, Y := R - M, and Y replaces R
  ;; when the unsigned comparison of R against (- M 2) (M already fixnumized)
  ;; is "above".
  (define-vop (fast-c-%himod)
    (:translate %himod)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target r))
    (:info m)
    (:arg-types positive-fixnum (:constant (unsigned-byte 31)))
    (:temporary (:sc any-reg :from :eval :to :result) y)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline constant %himod")
    (:vop-var vop)
    (:generator 4
      ;; maybe verbose
      (assert (gpr-tn-p x))
      ;; M may arrive as an immediate TN; unwrap it to its value.
      (when (sb-c:tn-p m)
        (assert (sb-c:sc-is m sb-vm::immediate))
        (setq m (sb-c::tn-value m)))
      (setq m (fixnumize m))
      (move r x)
      (inst cmp r (- m 2))
      ;; The effective-address constructor differs between SBCL versions.
      (inst lea y #.(if (fboundp 'ea)
                        `(sb-vm::ea (- m) r)
                        `(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp (- m) :base r)))
      (inst cmov :a r y))))

;; %lomod
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown %lomod ((signed-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
      (movable foldable flushable always-translatable)
    :overwrite-fndb-silently t)

  (defoptimizer (%lomod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))

  ;; Branch-free conditional addition: R := X, Y := R + M, and Y replaces R
  ;; when R is negative (OR R,R sets the flags; CMOV :L tests them).
  (define-vop (fast-c-%lomod)
    (:translate %lomod)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target r))
    (:info m)
    (:arg-types fixnum (:constant (unsigned-byte 31)))
    (:temporary (:sc any-reg :from :eval :to :result) y)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline constant %lomod")
    (:vop-var vop)
    (:generator 4
      (assert (gpr-tn-p x))
      ;; M may arrive as an immediate TN; unwrap it to its value.
      (when (sb-c:tn-p m)
        (assert (sb-c:sc-is m sb-vm::immediate))
        (setq m (sb-c::tn-value m)))
      (setq m (fixnumize m))
      (move r x)
      (inst or r r)
      ;; The effective-address constructor differs between SBCL versions.
      (inst lea y #.(if (fboundp 'ea)
                        `(sb-vm::ea m r)
                        `(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp m :base r)))
      (inst cmov :l r y))))

;; fast-mod
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown fast-mod (integer unsigned-byte) unsigned-byte
      (movable foldable flushable)
    :overwrite-fndb-silently t)

  (defoptimizer (fast-mod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))

  (defun fast-mod (number divisor)
    "Is equivalent to CL:MOD for integer arguments"
    (declare (explicit-check))
    (mod number divisor))

  ;; When NUMBER is an (unsigned-byte 62) and DIVISOR is a constant
  ;; (unsigned-byte 31) greater than 1, replace the call by
  ;;   q = high62(number * floor(2^62 / divisor)),  x = number - q * divisor
  ;; followed by one conditional subtraction of DIVISOR.
  (deftransform fast-mod ((number divisor)
                          ((unsigned-byte 62) (constant-arg (unsigned-byte 31))) *
                          :important t)
    "convert FAST-MOD to barrett reduction"
    (let ((mod (lvar-value divisor)))
      (if (<= mod 1)
          (give-up-ir1-transform)
          (let ((m (floor (ash 1 62) mod)))
            `(let* ((q (*-high62 number ,m))
                    (x (truly-the (signed-byte 32) (- number (* q ,mod)))))
               (if (< x ,mod) x (- x ,mod))
               ;; Not so effective because branch prediction is usually correct?
               ;; (sb-ext:truly-the (mod ,mod) (%himod x ,mod))
               ))))))

(defpackage :cp/read-fixnum
  (:use :cl)
  (:export #:read-fixnum))
(in-package :cp/read-fixnum)

(declaim (ftype (function * (values fixnum &optional)) read-fixnum))
(defun read-fixnum (&optional (in *standard-input*))
  "NOTE: cannot read -2^62"
  ;; Skips bytes until the first decimal digit (remembering whether a #\- was
  ;; seen on the way), then accumulates digits until the first non-digit byte.
  ;; A zero byte before any digit is reported as EOF.
  (macrolet ((%read-byte ()
               `(the (unsigned-byte 8)
                     #+swank (char-code (read-char in nil #\Nul))
                     #-swank (sb-impl::ansi-stream-read-byte in nil #.(char-code #\Nul) nil))))
    (let* ((minus nil)
           (result (loop (let ((byte (%read-byte)))
                           (cond ((<= 48 byte 57)
                                  (return (- byte 48)))
                                 ((zerop byte) ; #\Nul
                                  (error "Read EOF or #\Nul."))
                                 ((= byte #.(char-code #\-))
                                  (setq minus t)))))))
      (declare ((integer 0 #.most-positive-fixnum) result))
      (loop
        (let* ((byte (%read-byte)))
          (if (<= 48 byte 57)
              (setq result (+ (- byte 48)
                              (* 10 (the (integer 0 #.(floor most-positive-fixnum 10))
                                         result))))
              (return (if minus (- result) result))))))))

;;;
;;; Fast Number Theoretic Transform
;;; Reference:
;;; https://github.com/ei1333/library/blob/master/math/fft/number-theoretic-transform-friendly-mod-int.cpp
;;; https://github.com/atcoder/ac-library/tree/master/atcoder
;;;

(defpackage :cp/ntt
  (:use :cl)
  (:import-from :cp/experimental/barrett #:fast-mod #:%himod #:%lomod)
  (:export #:define-ntt #:check-ntt-vector #:ntt-int #:ntt-vector #:+ntt-mod+))
(in-package :cp/ntt)

(deftype ntt-int () '(unsigned-byte 31))
(deftype ntt-vector () '(simple-array ntt-int (*)))

;; These helpers are called while DEFINE-NTT is being expanded, so they must
;; be available at compile time as well.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (declaim (inline %tzcount))
  (defun %tzcount (x)
    "Returns the number of the trailing zero bits. Note that (%TZCOUNT 0) = -1."
    (- (integer-length (logand x (- x))) 1))

  (defun %mod-power (base exp modulus)
    "Returns BASE^EXP mod MODULUS by binary exponentiation."
    (declare (ntt-int base exp modulus))
    (let ((res 1))
      (declare (ntt-int res))
      ;; Only the first DO is guarded by WHEN; the squaring/shift step runs on
      ;; every iteration.
      (loop while (> exp 0)
            when (oddp exp)
            do (setq res (fast-mod (* res base) modulus))
            do (setq base (fast-mod (* base base) modulus)
                     exp (ash exp -1)))
      res))

  (defun %mod-inverse (x modulus)
    "Returns X^(MODULUS-2) mod MODULUS."
    (%mod-power x (- modulus 2) modulus))

  (defun %calc-generator (modulus)
    "MODULUS must be prime."
    (declare (ntt-int modulus))
    (assert (>= modulus 2))
    (case modulus
      ;; Precomputed answers for common moduli.
      (2 1)
      (167772161 3)
      (469762049 3)
      (754974721 11)
      (998244353 3)
      (otherwise
       ;; DIVS[0..END) collects the distinct prime divisors of MODULUS - 1.
       (let ((divs (make-array 20 :element-type 'ntt-int :initial-element 0))
             (end 1)
             (x (floor (- modulus 1) 2)))
         (declare ((integer 0 #.most-positive-fixnum) x))
         (setf (aref divs 0) 2)
         (loop while (evenp x)
               do (setq x (floor x 2)))
         (loop for i of-type ntt-int from 3 by 2
               while (<= (* i i) x)
               when (zerop (mod x i))
               do (setf (aref divs end) i)
                  (incf end)
                  (loop while (zerop (mod x i))
                        do (setq x (floor x i))))
         (when (> x 1)
           (setf (aref divs end) x)
           (incf end))
         ;; Return the first G >= 2 with G^((MODULUS-1)/d) /= 1 for every
         ;; collected divisor d.
         (loop for g of-type ntt-int from 2
               do (dotimes (i end (return-from %calc-generator g))
                    (when (= 1 (%mod-power g (floor (- modulus 1) (aref divs i)) modulus))
                      (return)))))))))

;; KLUDGE: This function depends on SBCL's behaviour. That is, ADJUST-ARRAY
;; isn't guaranteed to preserve a given VECTOR in ANSI CL.
(declaim (ftype (function * (values ntt-vector &optional)) %adjust-array))
(defun %adjust-array (vector length)
  "Returns an NTT-VECTOR of LENGTH holding the contents of VECTOR; new
elements are zero. When the length already matches, a fresh copy is returned."
  (declare (vector vector))
  (let ((vector (coerce vector 'ntt-vector)))
    (if (= (length vector) length)
        (copy-seq vector)
        (adjust-array vector length :initial-element 0))))

(defun check-ntt-vector (vector)
  "Asserts that the length of VECTOR is zero or a power of two and fits in
NTT-INT."
  (declare (optimize (speed 3))
           (vector vector))
  (let ((len (length vector)))
    ;; power of two
    (assert (zerop (logand len (- len 1))))
    (check-type len ntt-int)))

;; DEFINE-NTT defines, for the constant prime MODULUS, the functions named by
;; the keyword arguments:
;;   NTT          in-place forward transform (default symbol NTT!)
;;   INVERSE-NTT  in-place inverse transform (default INVERSE-NTT!); it
;;                multiplies by the inverse of the length only when its
;;                optional argument INVERSE is true
;;   CONVOLVE     convolution of two vectors (default CONVOLVE)
;;   MOD-POWER, MOD-INVERSE  modular helpers (default: fresh gensyms)
;; Default names are interned in the package current at expansion time.
(defmacro define-ntt (modulus &key ntt inverse-ntt convolve mod-inverse mod-power
                      &environment env)
  (assert (constantp modulus env))
  (let* ((modulus #+sbcl (sb-int:constant-form-value modulus env)
                  #-sbcl modulus)
         (ntt (or ntt (intern "NTT!")))
         (inverse-ntt (or inverse-ntt (intern "INVERSE-NTT!")))
         (convolve (or convolve (intern "CONVOLVE")))
         (mod-power (or mod-power (gensym "MOD-POWER")))
         (mod-inverse (or mod-inverse (gensym "MOD-INVERSE")))
         (ntt-base (gensym "*NTT-BASE*"))
         (ntt-inv-base (gensym "*NTT-INV-BASE*"))
         (base-size (%tzcount (- modulus 1)))
         (root (%calc-generator modulus)))
    (declare (ntt-int modulus))
    `(progn
       (declaim (inline ,mod-power))
       (defun ,mod-power (base exp)
         (declare (ntt-int base)
                  ((integer 0 #.most-positive-fixnum) exp))
         (let ((res 1))
           (declare (ntt-int res))
           ;; Only the first DO is guarded by WHEN.
           (loop while (> exp 0)
                 when (oddp exp)
                 do (setq res (fast-mod (* res base) ,modulus))
                 do (setq base (fast-mod (* base base) ,modulus)
                          exp (ash exp -1)))
           res))

       (declaim (inline ,mod-inverse))
       (defun ,mod-inverse (x)
         (,mod-power x (- ,modulus 2)))

       ;; Twiddle tables, indexed by (%tzcount k) in the butterfly loops.
       (declaim (ntt-vector ,ntt-base ,ntt-inv-base))
       (sb-ext:define-load-time-global ,ntt-base
         (make-array ,base-size :element-type 'ntt-int))
       (sb-ext:define-load-time-global ,ntt-inv-base
         (make-array ,base-size :element-type 'ntt-int))
       (dotimes (i ,base-size)
         (setf (aref ,ntt-base i)
               (mod (- (%mod-power ,root (ash (- ,modulus 1) (- (+ i 2))) ,modulus))
                    ,modulus)
               (aref ,ntt-inv-base i)
               (%mod-inverse (aref ,ntt-base i) ,modulus)))

       (declaim (ftype (function * (values ntt-vector &optional)) ,ntt))
       (defun ,ntt (vector)
         (declare (optimize (speed 3) (safety 0))
                  (vector vector))
         (check-ntt-vector vector)
         (labels ((mod* (x y) (fast-mod (* x y) ,modulus))
                  (mod+ (x y) (%himod (+ x y) ,modulus))
                  (mod- (x y) (%lomod (- x y) ,modulus)))
           (declare (inline mod* mod+ mod-))
           (let* ((vector (coerce vector 'ntt-vector))
                  (len (length vector))
                  (base ,ntt-base))
             (declare (ntt-vector vector base)
                      (ntt-int len))
             (when (<= len 1)
               (return-from ,ntt vector))
             ;; Half-width M goes len/2, len/4, ..., 1; W is updated once per
             ;; block of width 2M.
             (loop for m of-type ntt-int = (ash len -1) then (ash m -1)
                   while (> m 0)
                   for w of-type ntt-int = 1
                   for k of-type ntt-int = 0
                   do (loop for s of-type ntt-int from 0 below len by (* 2 m)
                            do (loop for i from s below (+ s m)
                                     for j from (+ s m)
                                     for x = (aref vector i)
                                     for y = (mod* (aref vector j) w)
                                     do (setf (aref vector i) (mod+ x y)
                                              (aref vector j) (mod- x y)))
                               (incf k)
                               (setq w (mod* w (aref base (%tzcount k))))))
             vector)))

       (declaim (ftype (function * (values ntt-vector &optional)) ,inverse-ntt))
       (defun ,inverse-ntt (vector &optional inverse)
         (declare (optimize (speed 3) (safety 0))
                  (vector vector))
         (check-ntt-vector vector)
         (labels ((mod* (x y)
                    (declare (ntt-int x y))
                    (fast-mod (* x y) ,modulus))
                  (mod+ (x y) (%himod (+ x y) ,modulus))
                  (mod- (x y) (%lomod (- x y) ,modulus)))
           (declare (inline mod* mod+ mod-))
           (let* ((vector (coerce vector 'ntt-vector))
                  (len (length vector))
                  (base ,ntt-inv-base))
             (declare (ntt-vector vector base)
                      (ntt-int len))
             (when (<= len 1)
               (return-from ,inverse-ntt vector))
             ;; Half-width M goes 1, 2, ..., len/2.
             (loop for m of-type ntt-int = 1 then (ash m 1)
                   while (< m len)
                   for w of-type ntt-int = 1
                   for k of-type ntt-int = 0
                   do (loop for s of-type ntt-int from 0 below len by (* 2 m)
                            do (loop for i from s below (+ s m)
                                     for j from (+ s m)
                                     for x = (aref vector i)
                                     for y = (aref vector j)
                                     do (setf (aref vector i) (mod+ x y)
                                              (aref vector j) (mod* (mod- x y) w)))
                               (incf k)
                               (setq w (mod* w (aref base (%tzcount k))))))
             ;; Scale by LEN^-1 only on request.
             (when inverse
               (let ((inv-len (,mod-power len (- ,modulus 2))))
                 (dotimes (i len)
                   (setf (aref vector i) (mod* inv-len (aref vector i))))))
             vector)))

       (declaim (ftype (function * (values ntt-vector &optional)) ,convolve))
       (defun ,convolve (vector1 vector2)
         (declare (optimize (speed 3))
                  (vector vector1 vector2))
         (let* ((len1 (length vector1))
                (len2 (length vector2))
                (mul-len (max 0 (- (+ len1 len2) 1)))
                (vector1 (coerce vector1 'ntt-vector))
                (vector2 (coerce vector2 'ntt-vector)))
           (declare (ntt-vector vector1 vector2)
                    ((mod #.array-total-size-limit) mul-len))
           (when (or (zerop len1) (zerop len2))
             (return-from ,convolve (make-array 0 :element-type 'ntt-int)))
           ;; naive convolution
           (when (<= (min len1 len2) 64)
             (let ((res (make-array mul-len :element-type 'ntt-int :initial-element 0)))
               (declare (optimize (speed 3) (safety 0)))
               (dotimes (d mul-len)
                 ;; 0 <= i <= deg1, 0 <= j <= deg2
                 (loop with coef of-type ntt-int = 0
                       for i from (max 0 (- d (- len2 1))) to (min d (- len1 1))
                       for j = (- d i)
                       do (setq coef (fast-mod (+ coef (* (aref vector1 i) (aref vector2 j)))
                                               ,modulus))
                       finally (setf (aref res d) coef)))
               (return-from ,convolve res)))
           (let* (;; power of two ceiling
                  (required-len (ash 1 (integer-length (max 0 (- mul-len 1)))))
                  (vector1 (,ntt (%adjust-array vector1 required-len)))
                  (vector2 (,ntt (%adjust-array vector2 required-len))))
             (dotimes (i required-len)
               (setf (aref vector1 i)
                     (fast-mod (* (aref vector1 i) (aref vector2 i)) ,modulus)))
             (adjust-array (,inverse-ntt vector1 t) mul-len)))))))

(defconstant +ntt-mod+ 998244353)

#+(or)
(define-ntt +ntt-mod+)

(defpackage :cp/polynomial-ntt
  (:use :cl :cp/ntt)
  (:import-from :cp/experimental/barrett #:fast-mod)
  (:export #:poly-multiply #:poly-inverse #:poly-floor #:poly-mod #:poly-sub #:poly-add
           #:multipoint-eval #:poly-total-prod))
(in-package :cp/polynomial-ntt)

;; TODO: integrate with cp/polynomial

;; Defines NTT!, INVERSE-NTT!, POLY-MULTIPLY and a one-argument %MOD-INVERSE in
;; this package.
(define-ntt +ntt-mod+
  :convolve poly-multiply
  :mod-inverse %mod-inverse)

(declaim (inline %adjust))
(defun %adjust (vector size)
  "Returns VECTOR itself when SIZE is NIL or equals its length; otherwise a
fresh zero-filled vector of SIZE with the contents of VECTOR copied in."
  (declare (ntt-vector vector))
  (if (or (null size) (= size (length vector)))
      vector
      (let ((res (make-array size :element-type 'ntt-int :initial-element 0)))
        (replace res vector)
        res)))

;; Reference: https://opt-cp.com/fps-fast-algorithms/
(declaim (ftype (function * (values ntt-vector &optional)) poly-inverse))
(defun poly-inverse (poly &optional result-length)
  "Returns the first RESULT-LENGTH (default: the length of POLY) coefficients
of the power-series inverse of POLY. Signals DIVISION-BY-ZERO when POLY is
empty or its constant term is zero."
  (declare #.cl-user::*opt*
           (vector poly)
           ((or null fixnum) result-length))
  (let* ((poly (coerce poly 'ntt-vector))
         (n (length poly)))
    (declare (ntt-vector poly))
    (when (or (zerop n)
              (zerop (aref poly 0)))
      (error 'division-by-zero
             :operation #'poly-inverse
             :operands (list poly)))
    (let* ((result-length (or result-length n))
           (res (make-array 1
                            :element-type 'ntt-int
                            :initial-element (%mod-inverse (aref poly 0)))))
      (declare (ntt-vector res))
      ;; RES holds I coefficients at the start of each round and 2I at the end.
      (loop for i of-type ntt-int = 1 then (ash i 1)
            while (< i result-length)
            for f of-type ntt-vector = (make-array (* 2 i)
                                                   :element-type 'ntt-int
                                                   :initial-element 0)
            for g of-type ntt-vector = (make-array (* 2 i)
                                                   :element-type 'ntt-int
                                                   :initial-element 0)
            do (replace f poly :end2 (min n (* 2 i)))
               (replace g res)
               (ntt! f)
               (ntt! g)
               (dotimes (j (* 2 i))
                 (setf (aref f j) (fast-mod (* (aref g j) (aref f j)) +ntt-mod+)))
               (inverse-ntt! f)
               ;; Keep only the upper half, moved down to the lower half.
               (replace f f :start1 0 :end1 i :start2 i :end2 (* 2 i))
               (fill f 0 :start i :end (* 2 i))
               (ntt! f)
               (dotimes (j (* 2 i))
                 (setf (aref f j) (fast-mod (* (aref g j) (aref f j)) +ntt-mod+)))
               (inverse-ntt! f)
               ;; Both INVERSE-NTT! calls above skipped the 1/(2I) scaling;
               ;; apply -(1/(2I))^2 once here.
               (let ((inv-len (%mod-inverse (* 2 i))))
                 (setq inv-len (fast-mod (* inv-len (- +ntt-mod+ inv-len)) +ntt-mod+))
                 (dotimes (j i)
                   (setf (aref f j) (fast-mod (* inv-len (aref f j)) +ntt-mod+)))
                 (setq res (%adjust res (* 2 i)))
                 (replace res f :start1 i)))
      (%adjust res result-length))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-floor))
(defun poly-floor (poly1 poly2)
  "Returns the quotient of POLY1 divided by POLY2; an empty vector when POLY2
has more significant coefficients than POLY1."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((poly1 (coerce poly1 'ntt-vector))
         (poly2 (coerce poly2 'ntt-vector))
         ;; Number of coefficients up to and including the last non-zero one.
         (deg1 (+ 1 (or (position 0 poly1 :from-end t :test-not #'eql) -1)))
         (deg2 (+ 1 (or (position 0 poly2 :from-end t :test-not #'eql) -1))))
    (when (> deg2 deg1)
      (return-from poly-floor (make-array 0 :element-type 'ntt-int)))
    (setq poly1 (nreverse (subseq poly1 0 deg1))
          poly2 (nreverse (subseq poly2 0 deg2)))
    (let* ((res-len (+ 1 (- deg1 deg2)))
           (res (adjust-array (poly-multiply poly1 (poly-inverse poly2 res-len))
                              res-len)))
      (nreverse res))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-sub))
(defun poly-sub (poly1 poly2)
  "Returns POLY1 - POLY2 with trailing zero coefficients removed."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((poly1 (coerce poly1 'ntt-vector))
         (poly2 (coerce poly2 'ntt-vector))
         (len (max (length poly1) (length poly2)))
         (res (make-array len :element-type 'ntt-int :initial-element 0)))
    (replace res poly1)
    (dotimes (i (length poly2))
      (let ((value (+ (aref res i)
                      (the ntt-int (- +ntt-mod+ (aref poly2 i))))))
        (setf (aref res i)
              (if (>= value +ntt-mod+)
                  (- value +ntt-mod+)
                  value))))
    (let ((end (+ 1 (or (position 0 res :from-end t :test-not #'eql) -1))))
      (adjust-array res end))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-add))
(defun poly-add (poly1 poly2)
  "Returns POLY1 + POLY2 with trailing zero coefficients removed."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((poly1 (coerce poly1 'ntt-vector))
         (poly2 (coerce poly2 'ntt-vector))
         (len (max (length poly1) (length poly2)))
         (res (make-array len :element-type 'ntt-int :initial-element 0)))
    (replace res poly1)
    (dotimes (i (length poly2))
      (let ((value (+ (aref res i) (aref poly2 i))))
        (setf (aref res i)
              (if (>= value +ntt-mod+)
                  (- value +ntt-mod+)
                  value))))
    (let ((end (+ 1 (or (position 0 res :from-end t :test-not #'eql) -1))))
      (adjust-array res end))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-mod))
(defun poly-mod (poly1 poly2)
  "Returns the remainder of POLY1 divided by POLY2 with trailing zero
coefficients removed; an empty vector when POLY1 is all zeros."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let ((poly1 (coerce poly1 'ntt-vector))
        (poly2 (coerce poly2 'ntt-vector)))
    (when (loop for x across poly1 always (zerop x))
      (return-from poly-mod (make-array 0 :element-type 'ntt-int)))
    (let* ((res (poly-sub poly1 (poly-multiply (poly-floor poly1 poly2) poly2)))
           (end (+ 1 (or (position 0 res :from-end t :test-not #'eql) -1))))
      (subseq res 0 end))))

(declaim (ftype (function * (values ntt-vector &optional)) poly-total-prod))
(defun poly-total-prod (polys)
  "Returns the total polynomial product: polys[0] * polys[1] * ... * polys[n-1]."
  (declare (vector polys))
  (let* ((n (length polys))
         (dp (make-array n :element-type t)))
    (declare ((mod #.array-total-size-limit) n))
    (when (zerop n)
      (return-from poly-total-prod
        (make-array 1 :element-type 'ntt-int :initial-element 1)))
    (replace dp polys)
    ;; Pairwise products with doubling stride; the result ends up in DP[0].
    (loop for width of-type (mod #.array-total-size-limit) = 1 then (ash width 1)
          while (< width n)
          do (loop for i of-type (mod #.array-total-size-limit) from 0 by (* width 2)
                   while (< (+ i width) n)
                   do (setf (aref dp i)
                            (poly-multiply (aref dp i) (aref dp (+ i width))))))
    (coerce (aref dp 0) 'ntt-vector)))

(declaim (ftype (function * (values ntt-vector &optional)) multipoint-eval))
(defun multipoint-eval (poly points)
  "Evaluates POLY at every element of POINTS and returns the values as a vector
of the same length. The length of POINTS must be zero or a power of two."
  (declare #.cl-user::*opt*
           (vector poly points)
           #+sbcl (sb-ext:muffle-conditions style-warning))
  (check-ntt-vector points)
  (let* ((poly (coerce poly 'ntt-vector))
         (points (coerce points 'ntt-vector))
         (len (length points))
         ;; Implicit binary tree: node POS has children 2*POS+1 and 2*POS+2.
         (table (make-array (max 0 (- (* 2 len) 1)) :element-type 'ntt-vector))
         (res (make-array len :element-type 'ntt-int)))
    (unless (zerop len)
      ;; Bottom-up: node POS holds the product of (x - points[i]) for i in [L, R).
      (sb-int:named-let %build ((l 0) (r len) (pos 0))
        (declare ((integer 0 #.most-positive-fixnum) l r pos))
        (if (= (- r l) 1)
            (let ((lin (make-array 2 :element-type 'ntt-int)))
              ;; NOTE: non-zero -- this entry is +NTT-MOD+ itself (not 0) when
              ;; the point is 0.
              (setf (aref lin 0) (- +ntt-mod+ (aref points l))
                    (aref lin 1) 1)
              (setf (aref table pos) lin))
            (let ((mid (ash (+ l r) -1)))
              (%build l mid (+ 1 (* pos 2)))
              (%build mid r (+ 2 (* pos 2)))
              (setf (aref table pos)
                    (poly-multiply (aref table (+ 1 (* pos 2)))
                                   (aref table (+ 2 (* pos 2))))))))
      ;; Top-down: reduce POLY modulo each child before descending; at a leaf
      ;; the constant term of the remainder is the value at that point.
      (sb-int:named-let %eval ((poly poly) (l 0) (r len) (pos 0))
        (declare ((integer 0 #.most-positive-fixnum) l r pos))
        (if (= (- r l) 1)
            (let ((tmp (poly-mod poly (aref table pos))))
              (setf (aref res l)
                    (if (zerop (length tmp)) 0 (aref tmp 0))))
            (let ((mid (ash (+ l r) -1)))
              (%eval (poly-mod poly (aref table (+ (* 2 pos) 1)))
                     l mid (+ (* 2 pos) 1))
              (%eval (poly-mod poly (aref table (+ (* 2 pos) 2)))
                     mid r (+ (* 2 pos) 2))))))
    res))

;; BEGIN_USE_PACKAGE
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/experimental/barrett :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/polynomial-ntt :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/ntt :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (use-package :cp/read-fixnum :cl-user))
(in-package :cl-user)

;;;
;;; Body
;;;

;; (defun test (sample mod)
;;   (let ((m (floor (ash 1 62) mod))
;;         (res 0))
;;     (dotimes (_ sample)
;;       (let* ((number (random most-positive-fixnum))
;;              (q (ldb (byte 62 62) (* number m)))
;;              (x (- number (* q mod))))
;;         (when (>= x mod)
;;           (incf res))
;;         (when (< x 0)
;;           (error "Huh?"))))
;;     (values res (float (/ res sample)))))

(defun main ()
  "Reads N, M, then N coefficients and M points from standard input, and prints
the M values separated by single spaces followed by a newline."
  (declare #.*opt*)
  (let* ((n (read-fixnum))
         (m (read-fixnum))
         ;; MULTIPOINT-EVAL needs a power-of-two number of points; the extra
         ;; slots stay 0 and their results are not printed.
         (m2 (sb-int:power-of-two-ceiling m))
         (cs (make-array n :element-type 'uint31 :initial-element 0))
         (ps (make-array m2 :element-type 'uint31 :initial-element 0)))
    (dotimes (i n)
      (setf (aref cs i) (read-fixnum)))
    (dotimes (i m)
      (setf (aref ps i) (read-fixnum)))
    (let ((res (multipoint-eval cs ps))
          (init t))
      ;; Build the whole output in a string first, then write it at once.
      (write-string
       (with-output-to-string (*standard-output* nil :element-type 'base-char)
         (dotimes (i m (terpri))
           (if init
               (setq init nil)
               (write-char #\Space))
           (write (aref res i))))))))

#-swank (main)

;;;
;;; Test
;;;

#+swank
(progn
  (defparameter *lisp-file-pathname* (uiop:current-lisp-file-pathname))
  (setq *default-pathname-defaults* (uiop:pathname-directory-pathname *lisp-file-pathname*))
  (uiop:chdir *default-pathname-defaults*)
  (defparameter *dat-pathname* (uiop:merge-pathnames* "test.dat" *lisp-file-pathname*))
  (defparameter *problem-url* "https://www.codechef.com/problems/POLYEVAL"))

#+swank
(defun gen-dat ()
  (uiop:with-output-file (out *dat-pathname* :if-exists :supersede)
    (format out "")))

#+swank
(defun bench (&optional (out (make-broadcast-stream)))
  (time (run *dat-pathname* out)))

#+(and sbcl (not swank))
(eval-when (:compile-toplevel)
  (when sb-c::*undefined-warnings*
    (error "undefined warnings: ~{~A~^ ~}" sb-c::*undefined-warnings*)))

;; To run: (5am:run! :sample)
#+swank
(5am:test :sample)