Submit Info #34206

Problem Lang User Status Time Memory
$\sum_{i=0}^{\infty} r^i i^d$ pypy3 Eki1009 AC 4598 ms 695.89 MiB

ケース詳細
Name Status Time Memory
0_00 AC 49 ms 29.83 MiB
0_01 AC 51 ms 29.80 MiB
0_02 AC 2365 ms 342.88 MiB
2_00 AC 51 ms 29.80 MiB
2_01 AC 50 ms 29.80 MiB
2_02 AC 76 ms 30.84 MiB
2_03 AC 246 ms 125.44 MiB
2_04 AC 4598 ms 695.82 MiB
2_05 AC 4323 ms 695.89 MiB
example_00 AC 51 ms 29.83 MiB

mod = 998244353


def inv_mod(a):
    """Return the inverse of ``a`` modulo ``mod``, or 0 when ``a`` is 0 mod ``mod``.

    Uses the extended Euclidean algorithm.
    """
    a %= mod
    if a == 0:
        return 0
    s, t = mod, a
    m0, m1 = 0, 1
    # Invariant: m0 * a == s and m1 * a == t (mod `mod`).
    while t:
        u = s // t
        s -= t * u
        m0 -= m1 * u
        s, t = t, s
        m0, m1 = m1, m0
    # s is now gcd(mod, a); for a prime mod and a != 0 it is 1, so this adds mod.
    if m0 < 0:
        m0 += mod // s
    return m0


# Lazily grown tables (all mod `mod`): fac_[i] = i!, finv_[i] = 1/i!, inv_[i] = 1/i.
# inv_[0] is a placeholder, since 1/0 is undefined.
fac_ = [1, 1]
finv_ = [1, 1]
inv_ = [1, 1]


def fac(i):
    """Return i! mod ``mod``, extending the cached table as needed."""
    while i >= len(fac_):
        fac_.append(fac_[-1] * len(fac_) % mod)
    return fac_[i]


def inv(i):
    """Return 1/i mod ``mod`` (for 1 <= i < mod), extending the cached table as needed."""
    while i >= len(inv_):
        j = len(inv_)
        # From mod = (mod // j) * j + mod % j:  1/j == -(mod // j) * 1/(mod % j).
        inv_.append(-inv_[mod % j] * (mod // j) % mod)
    return inv_[i]


def finv(i):
    """Return 1/i! mod ``mod``, extending the cached tables as needed."""
    inv(i)  # make sure inv_[1..i] exist before building finv_ from them
    while i >= len(finv_):
        finv_.append(finv_[-1] * inv_[len(finv_)] % mod)
    return finv_[i]


def comb(n, k):
    """Return the binomial coefficient C(n, k) mod ``mod`` (0 when k < 0 or k > n)."""
    if k < 0 or n < k:
        return 0
    return fac(n) * finv(k) % mod * finv(n - k) % mod


def lagrange_interpolation(y, a):
    """Evaluate at ``a`` the polynomial P of degree <= n with P(i) = y[i] for 0 <= i <= n.

    n = len(y) - 1 and all arithmetic is mod ``mod``.  ``a`` may be any
    integer (including huge or negative values).
    """
    n = len(y) - 1
    if 0 <= a <= n:
        return y[a]
    a %= mod  # only the residue matters from here on; keeps products small
    # dp[i] = prod_{j < i} (a - j)
    dp = [1] * (n + 1)
    for i in range(n):
        dp[i + 1] = dp[i] * (a - i) % mod
    # pd[i] = prod_{j > i} (a - j)
    pd = [1] * (n + 1)
    for i in range(n, 0, -1):
        pd[i - 1] = pd[i] * (a - i) % mod
    res = 0
    for i in range(n + 1):
        term = y[i] * dp[i] % mod * pd[i] % mod * finv(i) % mod * finv(n - i) % mod
        # The denominator prod_{j != i} (i - j) equals i! * (-1)^(n-i) * (n-i)!.
        if (n - i) & 1:
            res -= term
        else:
            res += term
    return res % mod


# Original (misspelled) name, kept so existing callers keep working.
lagrange_interpotation = lagrange_interpolation


def _weighted_prefix_sums(f, a):
    """Return g with g[i] = sum_{j <= i} a^j * f[j] mod ``mod``."""
    g = [0] * len(f)
    acc = 0
    power = 1  # a^i
    for i, fi in enumerate(f):
        acc = (acc + fi * power) % mod
        g[i] = acc
        power = power * a % mod
    return g


def _series_constant(g, a):
    """Return the constant c with sum_{i<m} a^i f(i) = c + a^m G(m), deg G <= k.

    ``g`` is the output of ``_weighted_prefix_sums`` for f, and k = len(g) - 1.
    Because the (k+1)-th finite difference of G vanishes,
        c = (sum_{i=0}^{k} C(k+1, i) (-a)^i g[k-i]) / (1 - a)^(k+1).
    c is also the formal value of sum_{i>=0} a^i f(i).  When a == 1 (mod
    ``mod``) the denominator is 0, and because inv_mod(0) == 0 the result is 0.
    """
    k = len(g) - 1
    c = 0
    power = 1  # (-a)^i
    for i in range(k + 1):
        c = (c + comb(k + 1, i) * power % mod * g[k - i]) % mod
        power = power * -a % mod
    return c * inv_mod(pow(1 - a, k + 1, mod)) % mod


def sum_of_exp(f, a, n):
    """Return sum_{i=0}^{n-1} a^i f(i) mod ``mod``.

    ``f`` lists f(0), ..., f(k) for a polynomial f of degree <= k.  ``n`` may
    be huge; the work is O(k) plus factorial-table lookups.
    """
    if n == 0:
        return 0
    a %= mod
    if a == 0:
        return f[0]  # only the i = 0 term survives (0^0 = 1)
    if a == 1:
        # Plain prefix sums S(m) = sum_{i<m} f(i) form a polynomial of degree <= k+1.
        s = [0] * (len(f) + 1)
        for i, fi in enumerate(f):
            s[i + 1] = (s[i] + fi) % mod
        return lagrange_interpolation(s, n)
    g = _weighted_prefix_sums(f, a)
    c = _series_constant(g, a)
    # g[m] = c + a^(m+1) G(m+1), so (g[m] - c) / a^m is a polynomial in m of degree <= k.
    ia = inv_mod(a)
    scale = 1  # a^(-i)
    for i in range(len(g)):
        g[i] = (g[i] - c) * scale % mod
        scale = scale * ia % mod
    tn = lagrange_interpolation(g, n - 1)
    return (tn * pow(a, n - 1, mod) + c) % mod


def sum_of_exp_limit(f, a):
    """Return the formal value of sum_{i=0}^{inf} a^i f(i) mod ``mod``.

    ``f`` lists f(0), ..., f(k) for a polynomial f of degree <= k.  The series
    is undefined for a == 1 (mod ``mod``); in that case 0 is returned.
    """
    a %= mod
    if a == 0:
        return f[0]
    return _series_constant(_weighted_prefix_sums(f, a), a)


def exp_enumerate(p, n):
    """Return [0^p, 1^p, ..., n^p] mod ``mod``, with 0^0 = 1.

    A linear sieve calls pow() only for primes.  Composite values follow from
    the complete multiplicativity of i -> i^p.
    """
    if not p:
        return [1] * (n + 1)
    f = [0] * (n + 1)
    if n < 1:
        return f  # nothing beyond 0^p = 0 to fill in
    f[1] = 1
    composite = bytearray(n + 1)
    primes = []
    for i in range(2, n + 1):
        if not composite[i]:
            f[i] = pow(i, p, mod)
            primes.append(i)
        fi = f[i]
        for q in primes:
            iq = i * q
            if iq > n:
                break
            composite[iq] = 1
            f[iq] = fi * f[q] % mod
            if i % q == 0:
                break  # iq is reached later via its smallest prime factor
    return f


def main():
    """Read ``r d`` from stdin and print sum_{i>=0} r^i i^d mod 998244353."""
    r, d = map(int, input().split())
    F = exp_enumerate(d, d)
    print(sum_of_exp_limit(F, r))


if __name__ == "__main__":
    main()