CF #457 E. Jamie and Tree
http://codeforces.com/problemset/problem/916/E
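Link-Cut Tree（LC-tree）は、根を変更するクエリと LCA を求めるクエリを処理できる。Euler Tour Tree（ET-tree）は、辺の追加・削除と、連結成分全体に x を一様に加算するクエリを処理できる。これらを組み合わせれば容易に解ける。現在の根に関する頂点 v の部分木への操作は、次の手順で行えばよい。まず LC-tree で v の親 p を求める。次に ET-tree で辺 (v, p) を一時的に切り、v を含む連結成分全体に加算（または総和の取得）を行う。最後に辺を繋ぎ直す。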
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <map>
#include <vector>
using namespace std;

// Euler Tour Tree over a forest.
//
// Every vertex i owns an "active" node super[i] that carries its value, and
// every directed edge (i, j) of the input graph owns an inactive node
// edges[{i, j}] that carries no value. Each connected component is one splay
// tree whose in-order sequence is that component's Euler tour. Consequently
// the splay root's `size`/`sum` aggregate the whole component, and a lazy
// addition placed at the root reaches every vertex of the component.
struct ETTree {
  struct node {
    node *l = nullptr;
    node *r = nullptr;
    node *p = nullptr;
    bool active;         // true for vertex nodes, false for edge nodes
    int size;            // number of active nodes in this splay subtree
    long long val = 0;   // this vertex's value (edge nodes keep 0)
    long long lazy = 0;  // addition already applied here, pending for children
    long long sum = 0;   // sum of `val` over active nodes in this splay subtree
    explicit node(bool active_, int v = 0)
        : active(active_), size(active_), val(v), sum(v) {}
  };

  map<pair<int, int>, node *> edges;
  vector<node *> super;
  const int n;

  // Builds one node per vertex (value a[i]) and one per directed edge of g,
  // then links every undirected edge once (i < j).
  ETTree(const vector<vector<int>> &g, const vector<int> &a) : n(g.size()) {
    super.resize(n);
    for (int i = 0; i < n; i++) {
      for (int j : g[i]) edges[{i, j}] = new node(false);
      super[i] = new node(true, a[i]);
    }
    for (int i = 0; i < n; i++)
      for (int j : g[i])
        if (i < j) link(i, j);
  }

  // Adds edge u-v: rotates both tours to start at their vertex, then splices
  // them as [u ... (u,v) v ... (v,u)].
  void link(int u, int v) {
    node *uv = edges[{u, v}];
    node *vu = edges[{v, u}];
    node *x = super[u];
    node *y = super[v];
    rotate(x);
    rotate(y);
    join(x, uv);
    join(uv, y);
    join(y, vu);
  }

  // Removes edge u-v: after rotating (u,v) to the front the tour is
  // [(u,v) <v side> (v,u) <u side>]; detach the two edge nodes.
  void cut(int u, int v) {
    node *uv = edges[{u, v}];
    node *vu = edges[{v, u}];
    rotate(uv);
    split_right(uv);
    split_left(vu);
    split_right(vu);
  }

  // Adds w to every vertex in u's component.
  void add(int u, long long w) { put_lazy(splay(super[u]), w); }

  // Number of vertices in u's component.
  int find_size(int u) { return splay(super[u])->size; }

  // Sum of vertex values in u's component.
  long long access(int u) { return splay(super[u])->sum; }

  // Single splay rotation of x over its parent; re-aggregates the old parent.
  void rot(node *x) {
    node *y = x->p, *z = y->p;
    if (z) {
      if (y == z->l) z->l = x;
      if (y == z->r) z->r = x;
    }
    x->p = z;
    y->p = x;
    if (x == y->l) {
      y->l = x->r;
      x->r = y;
      if (y->l) y->l->p = y;
    } else {
      y->r = x->l;
      x->l = y;
      if (y->r) y->r->p = y;
    }
    pull(y);
  }

  // Pushes pending additions from the splay root down to x (inclusive).
  void push_above(node *x) {
    if (x->p) push_above(x->p);
    push(x);
  }

  // Splays x to the root of its splay tree and returns it.
  node *splay(node *x) {
    push_above(x);
    while (x->p) {
      node *y = x->p, *z = y->p;
      if (z) rot((x == y->l) == (y == z->l) ? y : x);
      rot(x);
    }
    pull(x);
    return x;
  }

  // Recomputes size/sum of x from its children. Precondition: x has been
  // pushed (x->lazy == 0); otherwise the children's sums miss the addition.
  void pull(node *x) {
    x->size = x->active;
    if (x->l) x->size += x->l->size;
    if (x->r) x->size += x->r->size;
    x->sum = 0;
    if (x->active) x->sum += x->val;
    if (x->l) x->sum += x->l->sum;
    if (x->r) x->sum += x->r->sum;
  }

  // Applies "+v to every vertex in x's subtree": x's own val/sum are updated
  // immediately, the children are deferred through x->lazy.
  void put_lazy(node *x, long long v) {
    x->lazy += v;
    if (x->active) x->val += v;
    x->sum += v * x->size;
  }

  // Forwards x's pending addition to its children.
  void push(node *x) {
    if (x->l) put_lazy(x->l, x->lazy);
    if (x->r) put_lazy(x->r, x->lazy);
    x->lazy = 0;
  }

  // [_ _ x _ _] + [_ _ y _ _] -> [_ _ x _ _ _ _ y _ _]
  // Concatenates the sequence containing y after the one containing x.
  void join(node *x, node *y) {
    if (!x || !y) return;
    splay(x);
    while (x->r) x = x->r;  // last element; splay below pushes the path
    splay(x);
    splay(y);
    x->r = y;
    y->p = x;
    pull(x);
  }

  // [_ _ x _ _] -> [_ _] [x _ _]; returns the root of the left part (or null).
  node *split_left(node *x) {
    splay(x);
    node *y = x->l;
    if (y) {
      y->p = x->l = nullptr;
      pull(x);
    }
    return y;
  }

  // [_ _ x _ _] -> [_ _ x] [_ _]; returns the root of the right part (or null).
  node *split_right(node *x) {
    splay(x);
    node *y = x->r;
    if (y) {
      y->p = x->r = nullptr;
      pull(x);
    }
    return y;
  }

  // [0 1 2 x 3 4 5] -> [x 3 4 5 0 1 2]: makes x the first element of its tour.
  void rotate(node *x) { join(x, split_left(x)); }
};

// ---- Link-cut tree (no aggregates; supports reroot, LCA and parent query) ----

struct node {
  node *l = nullptr;
  node *r = nullptr;
  node *p = nullptr;  // splay parent, or path-parent when x is a splay root
  bool rev = false;   // x's own children already swapped; children pending
};

// True if x is the root of its splay (auxiliary) tree.
bool is_root(node *x) { return !x->p || (x != x->p->l && x != x->p->r); }

// Single rotation of x over its splay parent (no aggregates to update).
void rot(node *x) {
  node *y = x->p, *z = y->p;
  if (z) {
    if (y == z->l) z->l = x;
    if (y == z->r) z->r = x;
  }
  x->p = z;
  y->p = x;
  if (x == y->l) {
    y->l = x->r;
    x->r = y;
    if (y->l) y->l->p = y;
  } else {
    y->r = x->l;
    x->l = y;
    if (y->r) y->r->p = y;
  }
}

// Reverses x's subtree: swaps x's children now, defers the rest via `rev`.
void reverse(node *x) {
  swap(x->l, x->r);
  x->rev ^= true;
}

// Propagates x's pending reversal to its children.
void push(node *x) {
  if (x->rev) {
    if (x->l) reverse(x->l);
    if (x->r) reverse(x->r);
  }
  x->rev = false;
}

// Pushes every ancestor of x (following splay and path-parent links), then x.
void push_above(node *x) {
  if (x->p) push_above(x->p);
  push(x);
}

// Splays x to the root of its auxiliary tree. Assumes the path was pushed.
void splay(node *x) {
  while (!is_root(x)) {
    node *y = x->p;
    if (!is_root(y)) {
      node *z = y->p;
      rot((x == y->l) == (y == z->l) ? y : x);
    }
    rot(x);
  }
}

// Access: makes the root..x path preferred and leaves x at the root of its
// auxiliary tree with no right child.
void splayLC(node *x) {
  node *tmp = x;
  push_above(x);
  for (node *r = nullptr; x; x = x->p) {
    splay(x);
    x->r = r;
    r = x;
  }
  splay(tmp);
}

// Makes x the root of its represented tree.
void reroot(node *x) {
  splayLC(x);
  reverse(x);
}

// Adds edge c-p, making c a child of p.
void link(node *c, node *p) {
  reroot(c);
  splayLC(p);
  p->r = c;
  c->p = p;
}

// Cuts x from its parent in the represented tree.
void cut(node *x) {
  splayLC(x);
  x->l->p = nullptr;
  x->l = nullptr;
}

// LCA of x and y under the current root, or nullptr if they are in
// different trees. After accessing y then x, walks up from y; the last
// path-parent crossed is where y's path joins the root..x path.
node *lca(node *x, node *y) {
  splayLC(y);
  splayLC(x);
  bool same = false;
  node *l = y;
  while (y != nullptr) {
    if (is_root(y) && y->p) {
      l = y->p;
    }
    if (y == x->r) return x;
    if (x == y) same = true;
    y = y->p;
  }
  if (!same) return nullptr;
  return l;
}

// Parent of x in the represented tree under the current root, or nullptr if
// x is the root. After access(x) the parent is x's in-order predecessor,
// i.e. the rightmost node of x's left subtree.
node *get_parent(node *x) {
  splayLC(x);
  if (!x->l) {
    return nullptr;
  }
  x = x->l;
  // The left child may still carry a pending reversal; push it before
  // following its right spine, or its grandchildren are read unswapped.
  push(x);
  while (x->r) {
    x = x->r;
    push(x);
  }
  splayLC(x);
  return x;
}

int input() {
  int n;
  scanf("%d", &n);
  return n;
}

int main() {
  int n, q;
  cin >> n;
  cin >> q;
  vector<node *> lc(n);
  map<node *, int> mp;  // link-cut node -> vertex index (nullptr -> -1)
  mp[nullptr] = -1;
  vector<vector<int>> g(n);
  for (int i = 0; i < n; i++) {
    lc[i] = new node();
    mp[lc[i]] = i;
  }
  vector<int> a(n);
  for (int i = 0; i < n; i++) {
    a[i] = input();
  }
  for (int i = 0; i < n - 1; i++) {
    int u = input() - 1;
    int v = input() - 1;
    g[u].push_back(v);
    g[v].push_back(u);
    link(lc[u], lc[v]);
  }
  reroot(lc[0]);
  ETTree et(g, a);
  while (q--) {
    int type = input();
    if (type == 1) {
      // Change root.
      int v = input() - 1;
      reroot(lc[v]);
    } else if (type == 2) {
      // Add x to the subtree of lca(u, v): detach it from its parent,
      // add to the whole component, re-attach.
      int u = input() - 1;
      int v = input() - 1;
      int x = input();
      int l = mp[lca(lc[u], lc[v])];
      int p = mp[get_parent(lc[l])];
      if (p != -1) et.cut(l, p);
      et.add(l, x);
      if (p != -1) et.link(l, p);
    } else {
      // Sum over the subtree of v, computed the same way.
      int v = input() - 1;
      int p = mp[get_parent(lc[v])];
      if (p != -1) et.cut(v, p);
      long long ans = et.access(v);
      if (p != -1) et.link(v, p);
      printf("%lld\n", ans);
    }
  }
}
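恐らく普通の解法。頂点 0 を根として HLD とオイラーツアーを前計算しておき、現在の根 r は変数として持つだけにする。HLD では、r を根としたときの LCA（3 頂点 u, v, r の合流点）と、r を根としたときの親を求める。部分木への加算・総和は、オイラーツアー上の区間加算・区間和 BIT で処理する。r が v の（頂点 0 を根とした）部分木の中にある場合は場合分けが必要になる。このとき r を根とした v の部分木は、木全体から「v の子のうち r 側にある頂点の部分木」を除いたものになる。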
#include <algorithm>
#include <cstdio>
#include <functional>
#include <iostream>
#include <map>
#include <vector>
using namespace std;

int input() {
  int n;
  scanf("%d", &n);
  return n;
}

// Heavy-light decomposition of the tree g, rooted at vertex 0.
//   parent[u]: parent of u (-1 for vertex 0)
//   head[u]:   topmost vertex of u's chain
//   vid[u]:    DFS position (each chain occupies consecutive positions)
//   inv[i]:    vertex whose vid is i
// A child v of u continues u's chain only when 2 * size[v] > size[u].
// Note: both DFS passes are recursive, so recursion depth equals tree height.
struct HLD {
  using Tree = std::vector<std::vector<int>>;
  std::vector<int> parent, head, vid, inv;
  HLD(const Tree &g)
      : parent(g.size()), head(g.size()), vid(g.size()), inv(g.size()) {
    int k = 0;
    std::vector<int> size(g.size(), 1);
    std::function<void(int, int)> dfs = [&](int u, int p) {
      for (int v : g[u])
        if (v != p) dfs(v, u), size[u] += size[v];
    };
    std::function<void(int, int, int)> dfs2 = [&](int u, int p, int h) {
      parent[u] = p;
      head[u] = h;
      vid[u] = k++;
      inv[vid[u]] = u;
      // Heavy child first so that it gets vid[u] + 1 and stays on u's chain.
      for (int v : g[u])
        if (v != p && size[u] < size[v] * 2) dfs2(v, u, h);
      for (int v : g[u])
        if (v != p && size[u] >= size[v] * 2) dfs2(v, u, v);
    };
    dfs(0, -1);
    dfs2(0, -1, 0);
  }

  // Meeting point of a, b and c; main() uses it as the LCA of a and b when
  // the tree is rooted at c.
  int lca(int a, int b, int c) {
    while (true) {
      if (vid[a] < vid[b]) swap(a, b);
      if (vid[a] < vid[c]) swap(a, c);
      if (vid[b] < vid[c]) swap(b, c);
      if (head[a] == head[b]) return b;
      a = parent[head[a]];
    }
  }

  // Neighbour of u on the way to r, i.e. u's parent when the tree is rooted
  // at r; -1 if u == r.
  int find_parent(int u, int r) {
    if (u == r) return -1;
    while (vid[u] < vid[r]) {
      if (head[u] == head[r]) return inv[vid[u] + 1];  // r below u on u's chain
      if (parent[head[r]] == u) return head[r];        // r under a light child
      r = parent[head[r]];
    }
    return parent[u];  // r is not in u's subtree (rooted at 0)
  }
};

// Euler tour of g rooted at vertex 0: the subtree of u occupies [l[u], r[u]).
struct EulerTour {
  std::vector<int> l;
  std::vector<int> r;
  EulerTour(const std::vector<std::vector<int>> &g) {
    const int n = g.size();
    l.resize(n);
    r.resize(n);
    int k = 0;
    std::function<void(int, int)> dfs = [&](int u, int p) {
      l[u] = k++;
      for (int v : g[u])
        if (v != p) dfs(v, u);
      r[u] = k;
    };
    dfs(0, -1);
  }
};

// Fenwick tree supporting range add and range sum over half-open [l, r)
// on positions 0..n-1 (two-BIT formulation).
template <class T>
struct BIT {
  vector<T> s0;
  vector<T> s1;
  BIT(int n) : s0(n + 2), s1(n + 2) {}

  // Adds v to every position >= k.
  void add0(int k, T v) {
    for (int i = k + 1; i < static_cast<int>(s0.size()); i += i & -i) {
      s0[i] -= v * k;
      s1[i] += v;
    }
  }
  void add(int l, int r, T v) {
    add0(l, v);
    add0(r, -v);
  }

  // Sum of positions [0, k).
  T sum0(int k) {
    T t0 = 0;
    T t1 = 0;
    for (int i = k + 1; i > 0; i -= i & -i) {
      t0 += s0[i];
      t1 += s1[i];
    }
    return t0 + t1 * k;
  }
  T sum(int l, int r) { return sum0(r) - sum0(l); }
  T sumall() { return sum0(static_cast<int>(s0.size()) - 2); }
};

int main() {
  int n, q;
  cin >> n;
  cin >> q;
  vector<vector<int>> g(n);
  vector<int> a(n);
  for (int i = 0; i < n; i++) {
    a[i] = input();
  }
  for (int i = 0; i < n - 1; i++) {
    int u = input() - 1;
    int v = input() - 1;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  EulerTour et(g);
  HLD hld(g);
  int root = 0;
  long long all = 0;  // amount added to every vertex
  BIT<long long> bit(n);
  for (int i = 0; i < n; i++) {
    bit.add(et.l[i], et.l[i] + 1, a[i]);
  }
  while (q--) {
    int type = input();
    if (type == 1) {
      int v = input() - 1;
      root = v;
    } else if (type == 2) {
      int u = input() - 1;
      int v = input() - 1;
      int x = input();
      int l = hld.lca(u, v, root);
      int p = hld.find_parent(l, root);
      if (p == -1) {
        // l is the root: its subtree is the whole tree.
        all += x;
      } else if (p == hld.parent[l]) {
        // root lies outside l's 0-rooted subtree: subtree is unchanged.
        bit.add(et.l[l], et.r[l], x);
      } else {
        // root lies below child p of l: whole tree minus p's 0-rooted subtree.
        bit.add(et.l[p], et.r[p], -x);
        all += x;
      }
    } else {
      int v = input() - 1;
      int p = hld.find_parent(v, root);
      long long ans = 0;
      if (p == -1) {
        ans = all * n + bit.sumall();
      } else if (p == hld.parent[v]) {
        ans = all * (et.r[v] - et.l[v]) + bit.sum(et.l[v], et.r[v]);
      } else {
        ans += all * n + bit.sumall();
        ans -= all * (et.r[p] - et.l[p]) + bit.sum(et.l[p], et.r[p]);
      }
      printf("%lld\n", ans);
    }
  }
}
感想
ET-tree が遅い。