Submit Info #14373

Problem Lang User Status Time Memory
Dynamic Tree Subtree Add Subtree Sum cpp QCFium AC 562 ms 16.93 MiB

ケース詳細 (Case details)
Name Status Time Memory
example_00 AC 1 ms 0.67 MiB
max_random_00 AC 562 ms 16.92 MiB
max_random_01 AC 556 ms 16.92 MiB
max_random_02 AC 540 ms 16.92 MiB
max_random_03 AC 527 ms 16.93 MiB
max_random_04 AC 533 ms 16.92 MiB
random_00 AC 359 ms 11.17 MiB
random_01 AC 371 ms 12.95 MiB
random_02 AC 227 ms 5.29 MiB
random_03 AC 150 ms 13.59 MiB
random_04 AC 154 ms 2.29 MiB
small_00 AC 1 ms 0.74 MiB
small_01 AC 2 ms 0.72 MiB
small_02 AC 2 ms 0.68 MiB
small_03 AC 3 ms 0.68 MiB
small_04 AC 2 ms 0.67 MiB

// Link-cut tree solution for "Dynamic Tree Subtree Add Subtree Sum".
//
// Each vertex is a Node. Preferred paths are stored as splay trees linked
// through ch[0]/ch[1]; for a splay-tree root, `p` is the path-parent pointer
// (is_root() tells the two cases apart). Every node also keeps
// light_sum / light_size: the aggregated value sum and vertex count of the
// auxiliary trees hanging off it through path-parent ("light") edges, so the
// sum / size of an aux-tree root cover its whole represented subtree.
//
// Subtree add is lazy: `added` is the running total ever added to a node, and
// each node records in `cancel` the value of its parent's `added` that it has
// already absorbed; flush() pulls in the difference. Light children reach
// their path-parent through the same `p` pointer, so they are covered too.
#include <algorithm>
#include <cassert>
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <ctime>
#include <iostream>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Reads one int from stdin. Returns 0 if no int could be read, instead of
// handing back an uninitialized value.
int ri() {
	int n = 0;
	if (scanf("%d", &n) != 1) n = 0;
	return n;
}

struct Node;
extern Node *NONE;

struct Node {
// Scoped shorthands for the splay children; #undef'd at the end of Node.
#define l ch[0]
#define r ch[1]
	Node *ch[2] = {NONE, NONE}; // splay-tree children
	Node *p = NONE;             // splay parent, or path-parent if this is an aux root
	int64_t val = 0;            // this vertex's own value
	int64_t sum = 0;            // val + children's sum + light_sum
	int64_t added = 0;          // total lazily added to this node so far
	int64_t cancel = 0;         // p->added already absorbed by this node
	int size = 0;               // vertex count matching `sum`
	int64_t light_sum = 0;      // LIGHT: sum over light (path-parent) children
	int light_size = 0;         // LIGHT: vertex count over light children
	bool rev = false;           // pending left/right swap (from evert)

	// Recomputes sum/size from the children; they are flushed first so their
	// pending adds/reversals are reflected in what is read.
	void fetch() {
		if (l != NONE) l->flush();
		if (r != NONE) r->flush();
		sum = val + l->sum + r->sum + light_sum;
		size = 1 + l->size + r->size + light_size;
	}

	// Adds add_val to every vertex aggregated under this node. Descendants
	// pick it up lazily through `added` in their own flush().
	void add(int64_t add_val) {
		val += add_val;
		sum += add_val * size;
		added += add_val;
		light_sum += add_val * light_size;
	}

	// Applies the parent's not-yet-absorbed adds and any pending reversal.
	void flush() {
		if (p != NONE) {
			add(p->added - cancel);
			cancel = p->added;
		}
		if (rev) {
			std::swap(l, r);
			l->rev ^= 1;
			r->rev ^= 1;
			rev = false;
		}
	}

	// Lifts ch[!dir] into this node's position (dir == 1 lifts the left child).
	// Every node whose parent changes gets cancel reset to its new parent's
	// `added` so lazy adds are neither lost nor applied twice. Callers (splay)
	// flush this node and ch[!dir] beforehand.
	void rotate(int dir) {
		Node *new_root = ch[!dir];
		assert(new_root != NONE);
		if (new_root->ch[dir] != NONE) new_root->ch[dir]->flush();
		ch[!dir] = new_root->ch[dir];
		ch[!dir]->p = this;
		ch[!dir]->cancel = added;
		new_root->ch[dir] = this;
		if (p->l == this) p->l = new_root;
		if (p->r == this) p->r = new_root;
		new_root->p = p;
		new_root->cancel = p->added;
		p = new_root;
		cancel = new_root->added;
		fetch(), new_root->fetch();
	}

	// True if this is the root of its splay tree (p is NONE or a path-parent).
	bool is_root() { return p == NONE || (p->l != this && p->r != this); }

	// Splays this node to the root of its aux tree (zig / zig-zig / zig-zag),
	// flushing grandparent, parent and self top-down before each step.
	void splay() {
		while (!is_root()) {
			if (p->is_root()) {
				p->flush(), flush();
				p->rotate(p->l == this);
			} else {
				Node *pp = p->p;
				pp->flush(), p->flush(), flush();
				bool flag0 = pp->l == p;
				bool flag1 = p->l == this;
				if (flag0 == flag1) pp->rotate(flag0);
				p->rotate(flag1);
				if (flag0 != flag1) pp->rotate(flag0);
			}
		}
		flush();
	}

	// Makes the path from the represented root down to this node preferred
	// and splays this node to the root of that aux tree. Returns the last node
	// visited while walking path-parent pointers.
	Node *expose() {
		Node *prev = NONE;
		for (Node *cur = this; cur != NONE; cur = cur->p) {
			cur->splay();
			if (cur->r != NONE) {
				// The old preferred child becomes a light child: fold it into light_*.
				cur->r->flush();
				cur->light_size += cur->r->size;
				cur->light_sum += cur->r->sum;
			}
			cur->r = prev;
			if (cur->r != NONE) {
				// The new preferred child is no longer light: take it out of light_*.
				cur->r->flush();
				cur->light_size -= cur->r->size;
				cur->light_sum -= cur->r->sum;
			}
			cur->fetch();
			prev = cur;
		}
		splay();
		return prev;
	}

	// Attaches this node as a child of `parent`, making it parent's preferred
	// child. This node must be the root of its own represented tree (callers
	// evert() it, or it was just cut()).
	void link(Node *parent) {
		parent->expose();
		expose();
		p = parent;
		cancel = parent->added;
		p->r = this;
		p->fetch();
	}

	// Detaches this node from its parent in the represented tree. After
	// expose() the parent side is the left splay child, which is dereferenced
	// unconditionally, so this must not be called on a represented root.
	void cut() {
		expose();
		l->flush();
		l->p = NONE;
		l = NONE;
		fetch();
	}

	// Makes this node the root of its represented tree.
	void evert() {
		expose();
		rev ^= 1;
		flush();
	}
#undef l
#undef r
};

// Shared null sentinel: its sum and size stay 0, so fetch() can read a missing
// child's aggregates without a branch. NONE is still a null pointer while
// `new Node` runs (static zero-initialization precedes dynamic initialization),
// so NONE's own link fields start out null. They are only ever compared, never
// followed. Writes to NONE->p / cancel / rev (from rotate and flush) are
// harmless scratch.
Node *NONE = new Node;

// ---------------------------------------------------------------------------
// Randomized self-test harness (not called from main).
// ---------------------------------------------------------------------------

// Builds a random labelled tree on n vertices as a list of n-1 edges.
template <class T>
std::vector<std::pair<int, int> > random_tree(int n, T &rnd) {
	std::vector<std::pair<int, int> > res;
	for (int i = 1; i < n; i++) res.push_back({std::uniform_int_distribution<>(0, i - 1)(rnd), i});
	std::vector<int> perm(n);
	std::iota(perm.begin(), perm.end(), 0);
	std::shuffle(perm.begin(), perm.end(), rnd);
	for (auto &e : res) {
		if (std::uniform_int_distribution<>(0, 1)(rnd)) std::swap(e.first, e.second);
		e.first = perm[e.first];
		e.second = perm[e.second];
	}
	return res;
}

// Link-cut-tree implementation under test. All values start at 0.
struct Fast {
	std::vector<Node> nodes;
	Fast(const std::vector<std::pair<int, int> > &hens) {
		int n = hens.size() + 1;
		nodes.resize(n); // sized once, before any pointers into it are taken
		for (auto &i : nodes) i.fetch();
		for (auto e : hens) {
			nodes[e.first].evert();
			nodes[e.first].link(&nodes[e.second]);
		}
	}
	// Adds 1 to every vertex on v's side of edge (v, p).
	void add(int v, int p) {
		nodes[p].evert();
		nodes[v].cut();
		nodes[v].add(1);
		nodes[v].link(&nodes[p]);
	}
	// Sum of values on v's side of edge (v, p); returned as int64_t so the
	// 64-bit aggregate is not truncated.
	int64_t sum(int v, int p) {
		nodes[p].evert();
		nodes[v].cut();
		int64_t res = nodes[v].sum;
		nodes[v].link(&nodes[p]);
		return res;
	}
};

// Brute-force DFS reference implementation.
struct Gu {
	std::vector<std::vector<int> > hen;
	std::vector<int> val;
	Gu(const std::vector<std::pair<int, int> > &hens) {
		int n = hens.size() + 1;
		hen.resize(n);
		for (auto e : hens) {
			hen[e.first].push_back(e.second);
			hen[e.second].push_back(e.first);
		}
		val.resize(n);
	}
	// Adds 1 to every vertex reachable from i without crossing back into p.
	void add(int i, int p) {
		val[i]++;
		for (auto j : hen[i])
			if (j != p) add(j, i);
	}
	// Sums values reachable from i without crossing back into p.
	int sum(int i, int p) {
		int res = val[i];
		for (auto j : hen[i])
			if (j != p) res += sum(j, i);
		return res;
	}
};

std::random_device rnd_dev;
std::mt19937 rnd(rnd_dev() ^ time(nullptr));

// Runs q random add/sum queries on a random n-vertex tree against both
// implementations. On the first mismatch it dumps the case to stderr and
// returns false.
bool random_check(int n, int q) {
	// With fewer than two vertices there is no edge (v, p) to query, and
	// picking a random neighbour would index an empty adjacency list.
	if (n < 2) return true;
	auto tree = random_tree(n, rnd);
	struct Query {
		int type;
		int v;
		int p;
	};
	std::vector<std::vector<int> > hen(n);
	for (auto e : tree) hen[e.first].push_back(e.second), hen[e.second].push_back(e.first);
	std::vector<Query> qs;
	for (int i = 0; i < q; i++) {
		int t = std::uniform_int_distribution<>(1, 2)(rnd);
		int v = std::uniform_int_distribution<>(0, n - 1)(rnd);
		int p = hen[v][std::uniform_int_distribution<>(0, static_cast<int>(hen[v].size()) - 1)(rnd)];
		qs.push_back({t, v, p});
	}
	Gu gu(tree);
	Fast fast(tree);
	for (int i = 0; i < q; i++) {
		if (qs[i].type == 1) {
			gu.add(qs[i].v, qs[i].p);
			fast.add(qs[i].v, qs[i].p);
		} else {
			int64_t r0 = gu.sum(qs[i].v, qs[i].p);
			int64_t r1 = fast.sum(qs[i].v, qs[i].p);
			if (r0 != r1) {
				std::cerr << "!!! FAILED !!! at query #" << i << std::endl;
				std::cerr << "correct:" << r0 << " wrong:" << r1 << std::endl;
				std::cerr << n << " " << q << std::endl;
				for (auto e : tree) std::cerr << e.first << " " << e.second << std::endl;
				for (auto &query : qs) std::cerr << query.type << " " << query.v << " " << query.p << std::endl;
				return false;
			}
		}
	}
	return true;
}

// Input: n q, then n initial vertex values, then n-1 edges "a b", then q queries:
//   0 u v w x : cut edge (u, v), then link w to x;
//   1 v p x   : add x to every vertex on v's side of edge (v, p);
//   2 v p     : print the sum of values on v's side of edge (v, p)
//               (any type other than 0 or 1 is handled this way).
// NOTE(review): assumes the edges named in queries exist; this is not validated.
int main() {
	int n = ri();
	int q = ri();
	// Heap-allocated: a stack array of n Nodes is a non-standard VLA and can
	// overflow the default stack for large n.
	std::vector<Node> nodes(n);
	for (auto &i : nodes) i.val = ri(), i.fetch();
	for (int i = 1; i < n; i++) {
		int a = ri();
		int b = ri();
		nodes[a].evert();
		nodes[a].link(&nodes[b]);
	}
	for (int i = 0; i < q; i++) {
		int t = ri();
		if (t == 0) {
			int u = ri();
			int v = ri();
			nodes[u].evert();
			nodes[v].cut();
			int w = ri();
			int x = ri();
			nodes[w].evert();
			nodes[w].link(&nodes[x]);
		} else {
			int v = ri();
			int p = ri();
			// Root at p and detach v: v's aux tree is then exactly v's side.
			nodes[p].evert();
			nodes[v].cut();
			if (t == 1) nodes[v].add(ri());
			else printf("%" PRId64 "\n", nodes[v].sum);
			nodes[v].link(&nodes[p]);
		}
	}
	return 0;
}