;; Bootstrap (skipped under SWANK): unless the :CHILD-SBCL feature is already
;; present, run this same file as a script in a child SBCL process started
;; with a 256MB control stack, then quit with the child's exit code.
;; The child is told to push :CHILD-SBCL onto *FEATURES* before loading the
;; script, so it does not spawn yet another process.
#-swank
(unless (member :child-sbcl *features*)
(quit
:unix-status
(process-exit-code
(run-program *runtime-pathname*
`("--control-stack-size" "256MB"
"--noinform" "--disable-ldb" "--lose-on-corruption" "--end-runtime-options"
"--eval" "(push :child-sbcl *features*)"
"--script" ,(namestring *load-pathname*))
:output t :error t :input t))))
(in-package :cl-user)
;; Global setup evaluated at compile, load and execute time.
(eval-when (:compile-toplevel :load-toplevel :execute)
;; *OPT* is the OPTIMIZE declaration spliced into functions via #.*OPT*:
;; safety 2 under SWANK, safety 0 / debug 0 otherwise.
(defparameter *opt*
#+swank '(optimize (speed 3) (safety 2))
#-swank '(optimize (speed 3) (safety 0) (debug 0)))
#+swank (ql:quickload '(:cl-debug-print :fiveam :cp/util) :silent t)
#+swank (use-package :cp/util :cl-user)
;; Outside SWANK, #>FORM reads FORM and expands to (VALUES FORM).
#-swank (set-dispatch-macro-character
#\# #\> (lambda (s c p) (declare (ignore c p)) `(values ,(read s nil nil t))))
;; Add :POPCNT and :SSE4 to the compiler backend subfeatures.
#+sbcl (dolist (f '(:popcnt :sse4)) (pushnew f sb-c:*backend-subfeatures*))
;; Seed *RANDOM-STATE* from the second value returned by GET-TIME-OF-DAY.
#+sbcl (setq *random-state* (seed-random-state (nth-value 1 (get-time-of-day)))))
;; Outside SWANK, while compiling, break on any signalled WARNING that is not
;; a STYLE-WARNING.
#-swank (eval-when (:compile-toplevel)
(setq *break-on-signals* '(and warning (not style-warning))))
;; Under SWANK, #> is handled by cl-debug-print's reader instead.
#+swank (set-dispatch-macro-character #\# #\> #'cl-debug-print:debug-print-reader)
;; Define UINT<n> and INT<n> as shorthands for (UNSIGNED-BYTE n) and
;; (SIGNED-BYTE n) for every listed bit width. The type names are interned in
;; the package that is current when the macro is expanded.
(macrolet ((define-int-types (&rest widths)
             (flet ((type-name (prefix width)
                      (intern (format nil "~A~A" prefix width))))
               `(progn
                  ,@(loop for width in widths
                          collect `(deftype ,(type-name "UINT" width) ()
                                     '(unsigned-byte ,width))
                          collect `(deftype ,(type-name "INT" width) ()
                                     '(signed-byte ,width)))))))
  (define-int-types 2 4 7 8 15 16 31 32 62 63 64))
;; Modulus constant 10^9 + 7.
(defconstant +mod+ 1000000007)
;; Debug-print macro. Under SWANK it expands into a FORMAT call that writes
;; "<form> => <value>" to *ERROR-OUTPUT*; with several (or zero) forms, the
;; list of forms and the list of their values are printed instead.
;; Without SWANK the macro body is empty, so every use expands to NIL.
(defmacro dbg (&rest forms)
  (declare (ignorable forms))
  #+swank
  (if (and (consp forms) (endp (rest forms)))
      (let ((form (first forms)))
        `(format *error-output* "~A => ~A~%" ',form ,form))
      `(format *error-output* "~A => ~A~%" ',forms (list ,@forms))))
(declaim (inline println))
(defun println (obj &optional (stream *standard-output*))
  "Writes OBJ to STREAM with PRINC followed by a newline and returns OBJ.
While printing a DOUBLE-FLOAT, *READ-DEFAULT-FLOAT-FORMAT* is bound to
DOUBLE-FLOAT."
  (let ((*read-default-float-format*
          (typecase obj
            (double-float 'double-float)
            (t *read-default-float-format*))))
    (princ obj stream)
    (terpri stream)
    obj))
;; BEGIN_INSERTED_CONTENTS
(defpackage :cp/experimental/barrett
(:use :cl)
(:import-from #:sb-ext #:truly-the)
(:import-from #:sb-c
#:deftransform
#:derive-type
#:defoptimizer
#:defknown #:movable #:foldable #:flushable #:commutative
#:always-translatable
#:lvar-type
#:lvar-value
#:give-up-ir1-transform
#:integer-type-numeric-bounds
#:define-vop)
(:import-from #:sb-vm
#:move #:inst #:eax-offset #:edx-offset
#:any-reg #:control-stack #:unsigned-reg
#:positive-fixnum
#:fixnumize #:ea)
(:import-from #:sb-kernel #:specifier-type)
(:import-from #:sb-int #:explicit-check #:constant-arg)
(:export #:fast-mod #:%himod #:%lomod)
(:documentation "Provides Barrett reduction."))
(in-package :cp/experimental/barrett)
;; Helpers used by the type derivers and VOP generators defined later in this
;; package; available at compile, load and execute time.
(eval-when (:compile-toplevel :load-toplevel :execute)
;; DERIVE-*: takes the upper bounds of the integer types of lvars X and Y.
;; When both bounds are integers the result is the type
;; (INTEGER 0 <high1*high2 shifted right by 62 bits>), otherwise (INTEGER 0).
(defun derive-* (x y)
(let ((high1 (nth-value 1 (integer-type-numeric-bounds (lvar-type x))))
(high2 (nth-value 1 (integer-type-numeric-bounds (lvar-type y)))))
(specifier-type (if (and (integerp high1) (integerp high2))
`(integer 0 ,(ash (* high1 high2) -62))
`(integer 0)))))
;; DERIVE-MOD: when the upper bound HIGH of MODULUS's integer type is known,
;; the result is (INTEGER 0 (HIGH)) -- HIGH itself excluded -- otherwise
;; (INTEGER 0).
(defun derive-mod (modulus)
(let ((high (nth-value 1 (integer-type-numeric-bounds (lvar-type modulus)))))
(specifier-type (if (integerp high)
`(integer 0 (,high))
`(integer 0)))))
;; GPR-TN-P: decided at read time. If SB-VM has a symbol named GPR-TN-P,
;; the body calls that function on X; otherwise the function just returns T.
(defun gpr-tn-p (x)
(declare (ignorable x))
#.(if (find-symbol "GPR-TN-P" :sb-vm)
`(funcall (intern "GPR-TN-P" :sb-vm) x)
t)))
(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; *-high62
  ;; (*-HIGH62 X Y) takes two (UNSIGNED-BYTE 62) values and is meant to yield
  ;; floor(X*Y / 2^62) (cf. DERIVE-*, which bounds the result by the product
  ;; of the upper bounds shifted right by 62 bits). The function is declared
  ;; ALWAYS-TRANSLATABLE and is implemented by the two VOPs below.
  (defknown *-high62 ((unsigned-byte 62) (unsigned-byte 62)) (unsigned-byte 62)
      (movable foldable flushable commutative sb-c::always-translatable)
    :overwrite-fndb-silently t)
  (defoptimizer (*-high62 derive-type) ((x y))
    (derive-* x y))
  ;; Register/stack operand version. Both operands are tagged fixnums, i.e.
  ;; 2X and 2Y with the one-bit tag that the final SHL by 1 relies on, so MUL
  ;; leaves floor(4XY / 2^64) = floor(XY / 2^62) in EDX; shifting it left by
  ;; one turns it back into a tagged fixnum.
  (define-vop (fast-*-high62/fixnum)
    (:translate *-high62)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target eax)
           (y :scs (any-reg control-stack)))
    (:arg-types positive-fixnum positive-fixnum)
    (:temporary (:sc any-reg :offset eax-offset
                 :from (:argument 0) :to :result)
                eax)
    (:temporary (:sc any-reg :offset edx-offset :target r
                 :from :eval :to :result)
                edx)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline *-high62")
    (:vop-var vop)
    (:save-p :compute-only)
    (:generator 6
      (move eax x)
      (inst mul eax y)
      (inst shl edx 1)
      (move r edx)))
  ;; Constant second operand version. Y arrives through :INFO as a plain
  ;; (untagged) integer while X is a tagged fixnum. The constant therefore
  ;; has to be fixnumized before it is emitted: multiplying 2X by the raw Y
  ;; would leave floor(2XY / 2^64) = floor(XY / 2^63) in EDX, i.e. a result
  ;; shifted by one bit too many compared with the register version above.
  (define-vop (fast-c-*-high62-/fixnum)
    (:translate *-high62)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target eax))
    (:info y)
    (:arg-types positive-fixnum (:constant (unsigned-byte 62)))
    (:temporary (:sc any-reg :offset eax-offset
                 :from (:argument 0) :to :result)
                eax)
    (:temporary (:sc any-reg :offset edx-offset :target r
                 :from :eval :to :result)
                edx)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    (:note "inline constant *-high62")
    (:vop-var vop)
    (:save-p :compute-only)
    (:generator 6
      (move eax x)
      (inst mul eax (sb-c:register-inline-constant :qword (fixnumize y)))
      (inst shl edx 1)
      (move r edx)))
  ;; Out-of-line entry point; its body is compiled through the VOPs above.
  (defun *-high62 (x y)
    (declare (explicit-check))
    (*-high62 x y)))
;; %himod
;; (%HIMOD X M): declared for an (UNSIGNED-BYTE 32) X and a constant
;; (UNSIGNED-BYTE 31) M. The VOP performs one conditional subtraction of M.
(eval-when (:compile-toplevel :load-toplevel :execute)
(defknown %himod ((unsigned-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
(movable foldable flushable sb-c:always-translatable)
:overwrite-fndb-silently t)
;; The derived result type depends only on the modulus (see DERIVE-MOD).
(defoptimizer (%himod derive-type) ((integer modulus))
(declare (ignore integer))
(derive-mod modulus))
(define-vop (fast-c-%himod)
(:translate %himod)
(:policy :fast-safe)
(:args (x :scs (any-reg) :target r))
(:info m)
(:arg-types positive-fixnum (:constant (unsigned-byte 31)))
(:temporary (:sc any-reg :from :eval :to :result) y)
(:results (r :scs (any-reg)))
(:result-types positive-fixnum)
(:note "inline constant %himod")
(:vop-var vop)
(:generator
4
;; maybe verbose
(assert (gpr-tn-p x))
;; M may arrive as an immediate TN; if so, take its value.
(when (sb-c:tn-p m)
(assert (sb-c:sc-is m sb-vm::immediate))
(setq m (sb-c::tn-value m)))
;; From here on M is in tagged (fixnumized) form, like the register operand.
(setq m (fixnumize m))
(move r x)
;; Compare R with fixnumized M minus 2, compute R - M into Y with LEA, then
;; replace R with Y when the comparison found R "above" that bound.
;; NOTE(review): the comparison operand and the LEA displacement are emitted
;; as-is; presumably both must fit the instruction encodings -- confirm for
;; moduli close to 2^31.
(inst cmp r (- m 2))
;; The effective-address constructor is chosen at read time: EA when it is
;; fbound, otherwise SB-VM's MAKE-EA.
(inst lea y #.(if (fboundp 'ea)
`(sb-vm::ea (- m) r)
`(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp (- m) :base r)))
(inst cmov :a r y))))
;; %lomod
;; (%LOMOD X M): declared for a (SIGNED-BYTE 32) X and a constant
;; (UNSIGNED-BYTE 31) M. The VOP performs one conditional addition of M:
;; R becomes X + M when X is negative and stays X otherwise.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown %lomod ((signed-byte 32) (unsigned-byte 31)) (unsigned-byte 31)
      (movable foldable flushable always-translatable)
    :overwrite-fndb-silently t)
  ;; The derived result type depends only on the modulus (see DERIVE-MOD).
  (defoptimizer (%lomod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))
  (define-vop (fast-c-%lomod)
    (:translate %lomod)
    (:policy :fast-safe)
    (:args (x :scs (any-reg) :target r))
    (:info m)
    (:arg-types fixnum (:constant (unsigned-byte 31)))
    (:temporary (:sc any-reg :from :eval :to :result) y)
    (:results (r :scs (any-reg)))
    (:result-types positive-fixnum)
    ;; The note used to say "%himod" (copied from the sibling VOP), which made
    ;; compiler notes about this VOP point at the wrong function.
    (:note "inline constant %lomod")
    (:vop-var vop)
    (:generator
     4
     (assert (gpr-tn-p x))
     ;; M may arrive as an immediate TN; if so, take its value.
     (when (sb-c:tn-p m)
       (assert (sb-c:sc-is m sb-vm::immediate))
       (setq m (sb-c::tn-value m)))
     ;; From here on M is in tagged (fixnumized) form, like the register operand.
     (setq m (fixnumize m))
     (move r x)
     ;; OR R with itself only to set the flags from R's value; LEA then
     ;; computes R + M into Y, and CMOV replaces R with Y on the "less"
     ;; condition.
     (inst or r r)
     ;; The effective-address constructor is chosen at read time: EA when it
     ;; is fbound, otherwise SB-VM's MAKE-EA.
     (inst lea y #.(if (fboundp 'ea)
                       `(sb-vm::ea m r)
                       `(,(find-symbol "MAKE-EA" :sb-vm) :dword :disp m :base r)))
     (inst cmov :l r y))))
;; fast-mod
;; FAST-MOD behaves like CL:MOD. When the dividend is known to be an
;; (UNSIGNED-BYTE 62) and the divisor is a constant (UNSIGNED-BYTE 31) greater
;; than 1, the transform below replaces the call with a Barrett reduction
;; built on *-HIGH62.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (defknown fast-mod (integer unsigned-byte) unsigned-byte
      (movable foldable flushable)
    :overwrite-fndb-silently t)
  ;; The derived result type depends only on the modulus (see DERIVE-MOD).
  (defoptimizer (fast-mod derive-type) ((integer modulus))
    (declare (ignore integer))
    (derive-mod modulus))
  (defun fast-mod (number divisor)
    "Is equivalent to CL:MOD for integer arguments"
    (declare (explicit-check))
    (mod number divisor))
  (deftransform fast-mod ((number divisor)
                          ((unsigned-byte 62) (constant-arg (unsigned-byte 31))) *
                          :important t)
    "convert FAST-MOD to barrett reduction"
    (let ((mod (lvar-value divisor)))
      (if (<= mod 1)
          (give-up-ir1-transform)
          ;; M = floor(2^62 / MOD); Q = floor(NUMBER * M / 2^62) never exceeds
          ;; floor(NUMBER / MOD) and falls short of it by at most one, so
          ;; X = NUMBER - Q * MOD lies in [0, 2*MOD).
          (let ((m (floor (ash 1 62) mod)))
            `(let* ((q (*-high62 number ,m))
                    ;; X is never negative but can reach 2*MOD - 1, which
                    ;; exceeds (SIGNED-BYTE 32) once MOD is above 2^30, so the
                    ;; asserted type is (UNSIGNED-BYTE 32). This is also the
                    ;; argument type %HIMOD is declared with.
                    (x (truly-the (unsigned-byte 32) (- number (* q ,mod)))))
               (if (< x ,mod) x (- x ,mod))
               ;; Not so effective because branch prediction is usually correct?
               ;; (sb-ext:truly-the (mod ,mod) (%himod x ,mod))
               ))))))
(defpackage :cp/read-fixnum
(:use :cl)
(:export #:read-fixnum))
(in-package :cp/read-fixnum)
(declaim (ftype (function * (values fixnum &optional)) read-fixnum))
(defun read-fixnum (&optional (in *standard-input*))
  "Reads the next decimal integer from IN and returns it as a fixnum.
Bytes before the first digit are skipped; if one of them is a minus sign the
result is negated. Reading stops at the first non-digit byte after the digits,
which is consumed. Signals an error when end of file (or a NUL byte) is met
before any digit.
NOTE: cannot read -2^62"
  (macrolet ((%read-byte ()
               ;; Yields the next byte, or 0 at end of file.
               `(the (unsigned-byte 8)
                     #+swank (char-code (read-char in nil #\Nul))
                     #-swank (sb-impl::ansi-stream-read-byte in nil #.(char-code #\Nul) nil))))
    (let* ((minus nil)
           ;; Skip ahead to the first digit, remembering whether a #\- was seen.
           (result (loop (let ((byte (%read-byte)))
                           (cond ((<= 48 byte 57)
                                  (return (- byte 48)))
                                 ((zerop byte) ; #\Nul
                                  ;; The backslash is doubled so that the
                                  ;; message really reads "#\Nul"; a single
                                  ;; backslash in a string literal just escapes
                                  ;; the N and disappears.
                                  (error "Read EOF or #\\Nul."))
                                 ((= byte #.(char-code #\-))
                                  (setq minus t)))))))
      (declare ((integer 0 #.most-positive-fixnum) result))
      ;; Accumulate the remaining digits.
      (loop
        (let* ((byte (%read-byte)))
          (if (<= 48 byte 57)
              (setq result (+ (- byte 48)
                              (* 10 (the (integer 0 #.(floor most-positive-fixnum 10))
                                         result))))
              (return (if minus (- result) result))))))))
;;;
;;; Fast Number Theoretic Transform
;;; Reference:
;;; https://github.com/ei1333/library/blob/master/math/fft/number-theoretic-transform-friendly-mod-int.cpp
;;; https://github.com/atcoder/ac-library/tree/master/atcoder
;;;
(defpackage :cp/ntt
(:use :cl)
(:import-from :cp/experimental/barrett #:fast-mod #:%himod #:%lomod)
(:export #:define-ntt #:check-ntt-vector #:ntt-int #:ntt-vector #:+ntt-mod+))
(in-package :cp/ntt)
;; NTT-INT: element type of the vectors handled here, a 31-bit unsigned integer.
(deftype ntt-int () '(unsigned-byte 31))
;; NTT-VECTOR: simple one-dimensional array of NTT-INT.
(deftype ntt-vector () '(simple-array ntt-int (*)))
(eval-when (:compile-toplevel :load-toplevel :execute)
  (declaim (inline %tzcount))
  (defun %tzcount (x)
    "Index of the lowest set bit of X, i.e. the number of trailing zero bits.
(%TZCOUNT 0) evaluates to -1."
    (1- (integer-length (logand x (- x)))))
  (defun %mod-power (base exp modulus)
    "BASE raised to the power EXP modulo MODULUS, by repeated squaring."
    (declare (ntt-int base exp modulus))
    (do ((acc 1)
         (square base (fast-mod (* square square) modulus))
         (remaining exp (ash remaining -1)))
        ((not (> remaining 0)) acc)
      (declare (ntt-int acc square remaining))
      (when (oddp remaining)
        (setq acc (fast-mod (* acc square) modulus)))))
  (defun %mod-inverse (x modulus)
    "X raised to the power MODULUS - 2 modulo MODULUS."
    (%mod-power x (- modulus 2) modulus))
  (defun %calc-generator (modulus)
    "Returns a generator for MODULUS, which must be prime: a table lookup for a
few known moduli, otherwise the smallest G >= 2 such that G^((MODULUS-1)/p) is
not 1 for every prime factor p found for MODULUS - 1."
    (declare (ntt-int modulus))
    (assert (>= modulus 2))
    (case modulus
      (2 1)
      ((167772161 469762049 998244353) 3)
      (754974721 11)
      (otherwise
       (let ((factors (make-array 20 :element-type 'ntt-int :initial-element 0))
             (nfactors 1)
             (cofactor (floor (- modulus 1) 2)))
         (declare ((integer 0 #.most-positive-fixnum) cofactor))
         ;; 2 is always recorded as the first factor.
         (setf (aref factors 0) 2)
         ;; Remove the remaining factors of two.
         (do ()
             ((oddp cofactor))
           (setq cofactor (floor cofactor 2)))
         ;; Trial division by odd candidates.
         (do ((p 3 (+ p 2)))
             ((> (* p p) cofactor))
           (declare (ntt-int p))
           (when (zerop (mod cofactor p))
             (setf (aref factors nfactors) p)
             (incf nfactors)
             (do ()
                 ((not (zerop (mod cofactor p))))
               (setq cofactor (floor cofactor p)))))
         ;; Whatever is left above 1 is recorded as the last factor.
         (unless (<= cofactor 1)
           (setf (aref factors nfactors) cofactor)
           (incf nfactors))
         ;; Try candidates 2, 3, ... until one passes every factor test.
         (do ((g 2 (+ g 1)))
             (nil)
           (declare (ntt-int g))
           (when (loop for k below nfactors
                       never (= 1 (%mod-power g
                                              (floor (- modulus 1) (aref factors k))
                                              modulus)))
             (return g))))))))
;; KLUDGE: relies on SBCL behaviour: ANSI CL does not guarantee that
;; ADJUST-ARRAY leaves the vector passed to it intact.
(declaim (ftype (function * (values ntt-vector &optional)) %adjust-array))
(defun %adjust-array (vector length)
  "Returns VECTOR coerced to an NTT-VECTOR and brought to LENGTH elements:
a copy when the length already matches, otherwise the result of ADJUST-ARRAY
with new elements set to 0."
  (declare (vector vector))
  (let ((ntt-vec (coerce vector 'ntt-vector)))
    (cond ((/= (length ntt-vec) length)
           (adjust-array ntt-vec length :initial-element 0))
          (t
           (copy-seq ntt-vec)))))
;; CHECK-NTT-VECTOR: validates the length of VECTOR. Signals an error unless
;; the length has at most one bit set (so 0 is accepted as well as powers of
;; two) and is of type NTT-INT.
(defun check-ntt-vector (vector)
(declare (optimize (speed 3))
(vector vector))
(let ((len (length vector)))
(assert (zerop (logand len (- len 1)))) ;; power of two
(check-type len ntt-int)))
;; DEFINE-NTT expands into a family of functions specialised for the constant
;; MODULUS: a modular power and inverse, a forward transform, an inverse
;; transform and a convolution. The keyword arguments give the function names;
;; the defaults are NTT!, INVERSE-NTT! and CONVOLVE interned in the package
;; current at expansion time, and gensyms for the power/inverse helpers.
;; MODULUS must be a constant form (checked with CONSTANTP).
(defmacro define-ntt (modulus &key ntt inverse-ntt convolve mod-inverse mod-power
&environment env)
(assert (constantp modulus env))
(let* ((modulus #+sbcl (sb-int:constant-form-value modulus env) #-sbcl modulus)
(ntt (or ntt (intern "NTT!")))
(inverse-ntt (or inverse-ntt (intern "INVERSE-NTT!")))
(convolve (or convolve (intern "CONVOLVE")))
(mod-power (or mod-power (gensym "MOD-POWER")))
(mod-inverse (or mod-inverse (gensym "MOD-INVERSE")))
;; Names of the two global tables of per-level twiddle factors.
(ntt-base (gensym "*NTT-BASE*"))
(ntt-inv-base (gensym "*NTT-INV-BASE*"))
;; Number of trailing zero bits of MODULUS - 1; also the table size.
(base-size (%tzcount (- modulus 1)))
(root (%calc-generator modulus))
;; NOTE: MODULUS is bound a second time here; this re-evaluates the value
;; already obtained by the first binding, and uses SB-INT unconditionally.
(modulus (sb-int:constant-form-value modulus env)))
(declare (ntt-int modulus))
`(progn
;; Modular exponentiation by repeated squaring, specialised to the modulus.
(declaim (inline ,mod-power))
(defun ,mod-power (base exp)
(declare (ntt-int base)
((integer 0 #.most-positive-fixnum) exp))
(let ((res 1))
(declare (ntt-int res))
(loop while (> exp 0)
when (oddp exp)
do (setq res (fast-mod (* res base) ,modulus))
do (setq base (fast-mod (* base base) ,modulus)
exp (ash exp -1)))
res))
;; Inverse computed as X^(MODULUS - 2).
(declaim (inline ,mod-inverse))
(defun ,mod-inverse (x)
(,mod-power x (- ,modulus 2)))
(declaim (ntt-vector ,ntt-base ,ntt-inv-base))
(sb-ext:define-load-time-global ,ntt-base
(make-array ,base-size :element-type 'ntt-int))
(sb-ext:define-load-time-global ,ntt-inv-base
(make-array ,base-size :element-type 'ntt-int))
;; Table entry i is -(ROOT ^ ((MODULUS - 1) >> (i + 2))) mod MODULUS; the
;; second table holds the inverse of each entry.
(dotimes (i ,base-size)
(setf (aref ,ntt-base i)
(mod (- (%mod-power ,root (ash (- ,modulus 1) (- (+ i 2))) ,modulus))
,modulus)
(aref ,ntt-inv-base i)
(%mod-inverse (aref ,ntt-base i) ,modulus)))
;; Forward transform. VECTOR is coerced to an NTT-VECTOR and overwritten;
;; its length is validated by CHECK-NTT-VECTOR. Vectors of length <= 1 are
;; returned as they are.
(declaim (ftype (function * (values ntt-vector &optional)) ,ntt))
(defun ,ntt (vector)
(declare (optimize (speed 3) (safety 0))
(vector vector))
(check-ntt-vector vector)
(labels ((mod* (x y) (fast-mod (* x y) ,modulus))
(mod+ (x y) (%himod (+ x y) ,modulus))
(mod- (x y) (%lomod (- x y) ,modulus)))
(declare (inline mod* mod+ mod-))
(let* ((vector (coerce vector 'ntt-vector))
(len (length vector))
(base ,ntt-base))
(declare (ntt-vector vector base)
(ntt-int len))
(when (<= len 1)
(return-from ,ntt vector))
;; One pass per half-block size M = len/2, len/4, ..., 1. Within a pass, W
;; starts at 1 and is multiplied by BASE[tzcount(k)] after each block.
(loop for m of-type ntt-int = (ash len -1) then (ash m -1)
while (> m 0)
for w of-type ntt-int = 1
for k of-type ntt-int = 0
do (loop for s of-type ntt-int from 0 below len by (* 2 m)
do (loop for i from s below (+ s m)
for j from (+ s m)
for x = (aref vector i)
for y = (mod* (aref vector j) w)
do (setf (aref vector i) (mod+ x y)
(aref vector j) (mod- x y)))
(incf k)
(setq w (mod* w (aref base (%tzcount k))))))
vector)))
;; Inverse transform. The passes run with M = 1, 2, 4, ... and use the
;; inverse table. The result is multiplied by LEN^(MODULUS - 2) only when
;; INVERSE is true.
(declaim (ftype (function * (values ntt-vector &optional)) ,inverse-ntt))
(defun ,inverse-ntt (vector &optional inverse)
(declare (optimize (speed 3) (safety 0))
(vector vector))
(check-ntt-vector vector)
(labels ((mod* (x y)
(declare (ntt-int x y))
(fast-mod (* x y) ,modulus))
(mod+ (x y) (%himod (+ x y) ,modulus))
(mod- (x y) (%lomod (- x y) ,modulus)))
(declare (inline mod* mod+ mod-))
(let* ((vector (coerce vector 'ntt-vector))
(len (length vector))
(base ,ntt-inv-base))
(declare (ntt-vector vector base)
(ntt-int len))
(when (<= len 1)
(return-from ,inverse-ntt vector))
(loop for m of-type ntt-int = 1 then (ash m 1)
while (< m len)
for w of-type ntt-int = 1
for k of-type ntt-int = 0
do (loop for s of-type ntt-int from 0 below len by (* 2 m)
do (loop for i from s below (+ s m)
for j from (+ s m)
for x = (aref vector i)
for y = (aref vector j)
do (setf (aref vector i) (mod+ x y)
(aref vector j) (mod* (mod- x y) w)))
(incf k)
(setq w (mod* w (aref base (%tzcount k))))))
(when inverse
(let ((inv-len (,mod-power len (- ,modulus 2))))
(dotimes (i len)
(setf (aref vector i) (mod* inv-len (aref vector i))))))
vector)))
;; Convolution of VECTOR1 and VECTOR2 modulo the modulus. The result has
;; LEN1 + LEN2 - 1 elements, or none when either input is empty.
(declaim (ftype (function * (values ntt-vector &optional)) ,convolve))
(defun ,convolve (vector1 vector2)
(declare (optimize (speed 3))
(vector vector1 vector2))
(let* ((len1 (length vector1))
(len2 (length vector2))
(mul-len (max 0 (- (+ len1 len2) 1)))
(vector1 (coerce vector1 'ntt-vector))
(vector2 (coerce vector2 'ntt-vector)))
(declare (ntt-vector vector1 vector2)
((mod #.array-total-size-limit) mul-len))
(when (or (zerop len1) (zerop len2))
(return-from ,convolve (make-array 0 :element-type 'ntt-int)))
;; naive convolution
;; (used when the shorter input has at most 64 elements)
(when (<= (min len1 len2) 64)
(let ((res (make-array mul-len :element-type 'ntt-int :initial-element 0)))
(declare (optimize (speed 3) (safety 0)))
(dotimes (d mul-len)
;; 0 <= i <= deg1, 0 <= j <= deg2
(loop with coef of-type ntt-int = 0
for i from (max 0 (- d (- len2 1))) to (min d (- len1 1))
for j = (- d i)
do (setq coef (fast-mod
(+ coef (* (aref vector1 i) (aref vector2 j)))
,modulus))
finally (setf (aref res d) coef)))
(return-from ,convolve res)))
;; Otherwise: pad both inputs to the smallest power of two that is at least
;; MUL-LEN, transform, multiply element-wise, transform back with scaling,
;; and cut the result down to MUL-LEN elements.
(let* (;; power of two ceiling
(required-len (ash 1 (integer-length (max 0 (- mul-len 1)))))
(vector1 (,ntt (%adjust-array vector1 required-len)))
(vector2 (,ntt (%adjust-array vector2 required-len))))
(dotimes (i required-len)
(setf (aref vector1 i)
(fast-mod (* (aref vector1 i) (aref vector2 i)) ,modulus)))
(adjust-array (,inverse-ntt vector1 t) mul-len)))))))
;; Modulus used by the transforms instantiated from this file.
(defconstant +ntt-mod+ 998244353)
;; The #+(or) feature expression is never true, so the next form is skipped by
;; the reader: no transform is instantiated in this package.
#+(or)
(define-ntt +ntt-mod+)
(defpackage :cp/polynomial-ntt
(:use :cl :cp/ntt)
(:import-from :cp/experimental/barrett #:fast-mod)
(:export #:poly-multiply #:poly-inverse #:poly-floor #:poly-mod #:poly-sub #:poly-add
#:multipoint-eval #:poly-total-prod))
(in-package :cp/polynomial-ntt)
;; TODO: integrate with cp/polynomial
;; Instantiate the transform family for +NTT-MOD+ in this package. The
;; convolution is named POLY-MULTIPLY and the modular inverse %MOD-INVERSE;
;; the transforms get their default names NTT! and INVERSE-NTT!.
(define-ntt +ntt-mod+
:convolve poly-multiply
:mod-inverse %mod-inverse)
(declaim (inline %adjust))
(defun %adjust (vector size)
  "Returns VECTOR itself when SIZE is NIL or equals its length; otherwise a
fresh zero-filled NTT-VECTOR of SIZE elements holding as many leading elements
of VECTOR as fit."
  (declare (ntt-vector vector))
  (cond ((null size) vector)
        ((= size (length vector)) vector)
        (t (replace (make-array size :element-type 'ntt-int :initial-element 0)
                    vector))))
;; Reference: https://opt-cp.com/fps-fast-algorithms/
;; POLY-INVERSE: returns an NTT-VECTOR of RESULT-LENGTH coefficients
;; (default: the length of POLY). Signals DIVISION-BY-ZERO when POLY is empty
;; or its constant term is zero.
;; NOTE(review): this looks like the precision-doubling iteration for the
;; power-series inverse described at the reference above -- confirm there.
(declaim (ftype (function * (values ntt-vector &optional)) poly-inverse))
(defun poly-inverse (poly &optional result-length)
(declare #.cl-user::*opt*
(vector poly)
((or null fixnum) result-length))
(let* ((poly (coerce poly 'ntt-vector))
(n (length poly)))
(declare (ntt-vector poly))
(when (or (zerop n)
(zerop (aref poly 0)))
(error 'division-by-zero
:operation #'poly-inverse
:operands (list poly)))
;; RES starts as the one-element vector holding the inverse of POLY[0].
(let* ((result-length (or result-length n))
(res (make-array 1
:element-type 'ntt-int
:initial-element (%mod-inverse (aref poly 0)))))
(declare (ntt-vector res))
;; Each round works on scratch vectors F and G of length 2*I and extends
;; RES from I to 2*I elements; I doubles until it reaches RESULT-LENGTH.
(loop for i of-type ntt-int = 1 then (ash i 1)
while (< i result-length)
for f of-type ntt-vector = (make-array (* 2 i) :element-type 'ntt-int
:initial-element 0)
for g of-type ntt-vector = (make-array (* 2 i) :element-type 'ntt-int
:initial-element 0)
;; F <- first 2*I coefficients of POLY, G <- current RES; both transformed.
do (replace f poly :end2 (min n (* 2 i)))
(replace g res)
(ntt! f)
(ntt! g)
(dotimes (j (* 2 i))
(setf (aref f j) (fast-mod (* (aref g j) (aref f j)) +ntt-mod+)))
(inverse-ntt! f)
;; Keep only the upper half of F, moved down to the lower half.
(replace f f :start1 0 :end1 i :start2 i :end2 (* 2 i))
(fill f 0 :start i :end (* 2 i))
(ntt! f)
(dotimes (j (* 2 i))
(setf (aref f j) (fast-mod (* (aref g j) (aref f j)) +ntt-mod+)))
(inverse-ntt! f)
;; INVERSE-NTT! was called twice without its scaling argument, so the
;; factor applied here is -(1/(2*I))^2 modulo +NTT-MOD+.
(let ((inv-len (%mod-inverse (* 2 i))))
(setq inv-len (fast-mod (* inv-len (- +ntt-mod+ inv-len))
+ntt-mod+))
(dotimes (j i)
(setf (aref f j) (fast-mod (* inv-len (aref f j)) +ntt-mod+)))
;; Grow RES to 2*I elements and store the new coefficients from index I on.
(setq res (%adjust res (* 2 i)))
(replace res f :start1 i)))
(%adjust res result-length))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-floor))
(defun poly-floor (poly1 poly2)
  "Quotient of the polynomial division POLY1 / POLY2. Trailing zero
coefficients of both arguments are ignored. Returns an empty vector when POLY2
has more significant coefficients than POLY1."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (flet ((significant-length (coefs)
           ;; One past the index of the last non-zero coefficient; 0 if none.
           (let ((last-nonzero (position 0 coefs :from-end t :test-not #'eql)))
             (if last-nonzero (+ 1 last-nonzero) 0))))
    (let* ((dividend (coerce poly1 'ntt-vector))
           (divisor (coerce poly2 'ntt-vector))
           (len1 (significant-length dividend))
           (len2 (significant-length divisor)))
      (if (> len2 len1)
          (make-array 0 :element-type 'ntt-int)
          ;; Work on the reversed coefficient lists: the quotient is the first
          ;; QUOT-LEN coefficients of rev(dividend) * rev(divisor)^-1, reversed.
          (let* ((rev1 (nreverse (subseq dividend 0 len1)))
                 (rev2 (nreverse (subseq divisor 0 len2)))
                 (quot-len (+ 1 (- len1 len2)))
                 (product (poly-multiply rev1 (poly-inverse rev2 quot-len))))
            (nreverse (adjust-array product quot-len)))))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-sub))
(defun poly-sub (poly1 poly2)
  "Coefficient-wise difference POLY1 - POLY2 modulo +NTT-MOD+, with trailing
zero coefficients removed."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((minuend (coerce poly1 'ntt-vector))
         (subtrahend (coerce poly2 'ntt-vector))
         (diff (make-array (max (length minuend) (length subtrahend))
                           :element-type 'ntt-int :initial-element 0)))
    (replace diff minuend)
    (loop for i below (length subtrahend)
          do (let ((total (+ (aref diff i)
                             (the ntt-int (- +ntt-mod+ (aref subtrahend i))))))
               (setf (aref diff i)
                     (if (< total +ntt-mod+)
                         total
                         (- total +ntt-mod+)))))
    ;; Cut the result right after its last non-zero coefficient.
    (let ((last-nonzero (position 0 diff :from-end t :test-not #'eql)))
      (adjust-array diff (if last-nonzero (+ 1 last-nonzero) 0)))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-add))
(defun poly-add (poly1 poly2)
  "Coefficient-wise sum POLY1 + POLY2 modulo +NTT-MOD+, with trailing zero
coefficients removed."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let* ((augend (coerce poly1 'ntt-vector))
         (addend (coerce poly2 'ntt-vector))
         (sum (make-array (max (length augend) (length addend))
                          :element-type 'ntt-int :initial-element 0)))
    (replace sum augend)
    (loop for i below (length addend)
          do (let ((total (+ (aref sum i) (aref addend i))))
               (setf (aref sum i)
                     (if (< total +ntt-mod+)
                         total
                         (- total +ntt-mod+)))))
    ;; Cut the result right after its last non-zero coefficient.
    (let ((last-nonzero (position 0 sum :from-end t :test-not #'eql)))
      (adjust-array sum (if last-nonzero (+ 1 last-nonzero) 0)))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-mod))
(defun poly-mod (poly1 poly2)
  "Remainder of the polynomial division POLY1 / POLY2, computed as
POLY1 - floor(POLY1 / POLY2) * POLY2, without trailing zero coefficients.
Returns an empty vector when every coefficient of POLY1 is zero."
  (declare (optimize (speed 3))
           (vector poly1 poly2))
  (let ((dividend (coerce poly1 'ntt-vector))
        (divisor (coerce poly2 'ntt-vector)))
    (if (every #'zerop dividend)
        (make-array 0 :element-type 'ntt-int)
        (let* ((quotient (poly-floor dividend divisor))
               (remainder (poly-sub dividend (poly-multiply quotient divisor)))
               (last-nonzero (position 0 remainder :from-end t :test-not #'eql)))
          (subseq remainder 0 (if last-nonzero (+ 1 last-nonzero) 0))))))
(declaim (ftype (function * (values ntt-vector &optional)) poly-total-prod))
(defun poly-total-prod (polys)
  "Returns the total polynomial product: polys[0] * polys[1] * ... * polys[n-1].
For an empty POLYS the result is the one-element vector holding 1."
  (declare (vector polys))
  (let* ((n (length polys))
         ;; Working copy, so that POLYS itself is never modified.
         (work (make-array n :element-type t)))
    (declare ((mod #.array-total-size-limit) n))
    (cond ((zerop n)
           (make-array 1 :element-type 'ntt-int :initial-element 1))
          (t
           (replace work polys)
           ;; Pairwise merging: in each round, slot I absorbs slot I + WIDTH.
           (do ((width 1 (ash width 1)))
               ((>= width n))
             (declare ((mod #.array-total-size-limit) width))
             (do ((i 0 (+ i (* width 2))))
                 ((>= (+ i width) n))
               (declare ((mod #.array-total-size-limit) i))
               (setf (aref work i)
                     (poly-multiply (aref work i) (aref work (+ i width))))))
           (coerce (aref work 0) 'ntt-vector)))))
;; MULTIPOINT-EVAL: evaluates POLY at every element of POINTS and returns the
;; values as an NTT-VECTOR in the same order. The length of POINTS is
;; validated by CHECK-NTT-VECTOR.
(declaim (ftype (function * (values ntt-vector &optional)) multipoint-eval))
(defun multipoint-eval (poly points)
(declare #.cl-user::*opt*
(vector poly points)
#+sbcl (sb-ext:muffle-conditions style-warning))
(check-ntt-vector points)
(let* ((poly (coerce poly 'ntt-vector))
(points (coerce points 'ntt-vector))
(len (length points))
;; TABLE holds a binary tree of polynomials in array layout: the children of
;; node POS are stored at 2*POS + 1 and 2*POS + 2.
(table (make-array (max 0 (- (* 2 len) 1)) :element-type 'ntt-vector))
(res (make-array len :element-type 'ntt-int)))
(unless (zerop len)
;; %BUILD fills TABLE for the index range [L, R): a leaf gets the linear
;; polynomial with coefficients (+NTT-MOD+ - point, 1); an inner node gets the
;; product of its two children.
(sb-int:named-let %build ((l 0) (r len) (pos 0))
(declare ((integer 0 #.most-positive-fixnum) l r pos))
(if (= (- r l) 1)
(let ((lin (make-array 2 :element-type 'ntt-int)))
(setf (aref lin 0) (- +ntt-mod+ (aref points l)) ;; NOTE: non-zero
(aref lin 1) 1)
(setf (aref table pos) lin))
(let ((mid (ash (+ l r) -1)))
(%build l mid (+ 1 (* pos 2)))
(%build mid r (+ 2 (* pos 2)))
(setf (aref table pos)
(poly-multiply (aref table (+ 1 (* pos 2)))
(aref table (+ 2 (* pos 2))))))))
;; %EVAL walks the same tree from the root: on the way down the polynomial is
;; reduced modulo each child's polynomial; at a leaf it is reduced modulo the
;; leaf's linear polynomial and the constant term (0 if the remainder is
;; empty) is stored in RES.
(sb-int:named-let %eval ((poly poly) (l 0) (r len) (pos 0))
(declare ((integer 0 #.most-positive-fixnum) l r pos))
(if (= (- r l) 1)
(let ((tmp (poly-mod poly (aref table pos))))
(setf (aref res l) (if (zerop (length tmp)) 0 (aref tmp 0))))
(let ((mid (ash (+ l r) -1)))
(%eval (poly-mod poly (aref table (+ (* 2 pos) 1))) l mid (+ (* 2 pos) 1))
(%eval (poly-mod poly (aref table (+ (* 2 pos) 2))) mid r (+ (* 2 pos) 2))))))
res))
;; BEGIN_USE_PACKAGE
(eval-when (:compile-toplevel :load-toplevel :execute)
(use-package :cp/experimental/barrett :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
(use-package :cp/polynomial-ntt :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
(use-package :cp/ntt :cl-user))
(eval-when (:compile-toplevel :load-toplevel :execute)
(use-package :cp/read-fixnum :cl-user))
(in-package :cl-user)
;;;
;;; Body
;;;
;; (defun test (sample mod)
;; (let ((m (floor (ash 1 62) mod))
;; (res 0))
;; (dotimes (_ sample)
;; (let* ((number (random most-positive-fixnum))
;; (q (ldb (byte 62 62) (* number m)))
;; (x (- number (* q mod))))
;; (when (>= x mod)
;; (incf res))
;; (when (< x 0)
;; (error "Huh?"))))
;; (values res (float (/ res sample)))))
;; Reads N and M, then N coefficients and M points from standard input,
;; evaluates the polynomial at all points and prints the M values on one line
;; separated by single spaces. The point vector is padded with zeros up to the
;; next power of two before it is handed to MULTIPOINT-EVAL.
(defun main ()
  (declare #.*opt*)
  (let* ((n (read-fixnum))
         (m (read-fixnum))
         (padded-m (sb-int:power-of-two-ceiling m))
         (coefs (make-array n :element-type 'uint31 :initial-element 0))
         (points (make-array padded-m :element-type 'uint31 :initial-element 0)))
    (loop for i below n
          do (setf (aref coefs i) (read-fixnum)))
    (loop for i below m
          do (setf (aref points i) (read-fixnum)))
    (let ((answers (multipoint-eval coefs points)))
      ;; Build the whole output line in memory and write it in one go.
      (write-string
       (with-output-to-string (*standard-output* nil :element-type 'base-char)
         (dotimes (i m)
           (unless (zerop i)
             (write-char #\Space))
           (write (aref answers i)))
         (terpri))))))
#-swank (main)
;;;
;;; Test
;;;
;; SWANK only: remember this file's pathname, make its directory the default
;; and current directory, and define the test data file path ("test.dat"
;; merged with this file's pathname) and the problem URL.
#+swank
(progn
(defparameter *lisp-file-pathname* (uiop:current-lisp-file-pathname))
(setq *default-pathname-defaults* (uiop:pathname-directory-pathname *lisp-file-pathname*))
(uiop:chdir *default-pathname-defaults*)
(defparameter *dat-pathname* (uiop:merge-pathnames* "test.dat" *lisp-file-pathname*))
(defparameter *problem-url* "https://www.codechef.com/problems/POLYEVAL"))
;; SWANK only: (re)creates the test data file. As written it produces an empty
;; file, since the FORMAT control string is empty.
#+swank
(defun gen-dat ()
(uiop:with-output-file (out *dat-pathname* :if-exists :supersede)
(format out "")))
;; SWANK only: times a call of RUN on the test data file, sending output to
;; OUT (by default a broadcast stream with no targets, which discards it).
;; RUN is not defined in the visible part of this file.
#+swank
(defun bench (&optional (out (make-broadcast-stream)))
(time (run *dat-pathname* out)))
;; SBCL outside SWANK, at compile time only: signal an error listing the
;; contents of SB-C::*UNDEFINED-WARNINGS* if it is non-empty at this point.
#+(and sbcl (not swank))
(eval-when (:compile-toplevel)
(when sb-c::*undefined-warnings*
(error "undefined warnings: ~{~A~^ ~}" sb-c::*undefined-warnings*)))
;; To run: (5am:run! :sample)
;; SWANK only: declares a FiveAM test named :SAMPLE with an empty body.
#+swank
(5am:test :sample)