ARC 089 D. Checker
https://beta.atcoder.jp/contests/arc089/tasks/arc089_b
この解説では、座標を(横,縦)で表している。
$(x,y)$ が黒であることと $(x \bmod 2K, y \bmod 2K)$ が黒であることは同値である。
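$(x,y)$ が白であることと $(x+K,y)$ が黒であることは同値である。したがって、白の希望は $x$ に $K$ を足すことで黒の希望に置き換えられ、すべての希望を $[0,2K)^2$ 内の黒の希望として扱える。あとは市松模様のずらし方 ($2K \times 2K$ 通り) を全探索し、それぞれで満たされる希望の数を二次元累積和で数えればよい。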
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;

// ARC 089 D "Checker".
// Every request is turned into a "black" request inside one 2K x 2K period:
// a white request at (x, y) is the same as a black request at (x + K, y).
// The reduced points are copied 2 x 2 into a 4K x 4K grid so that every
// K x K window starting inside [0, 2K)^2 is one 2D prefix-sum query.
// A colouring is fixed by the corner (r, c) of one black K x K square; the
// other black square of the period then starts at (r + K, c + K).
int main() {
    int n, k;
    cin >> n >> k;
    const int period = 2 * k;
    const int side = 4 * k + 1;  // one extra leading row/column of zeros
    vector<vector<int>> pref(side, vector<int>(side));

    for (int q = 0; q < n; q++) {
        int x, y;
        char color;
        cin >> x >> y >> color;
        if (color == 'W') x += k;  // white at (x, y) == black at (x + K, y)
        x %= period;
        y %= period;
        // Drop the point into all four copies of the period.
        for (int dx = 0; dx <= period; dx += period) {
            for (int dy = 0; dy <= period; dy += period) {
                pref[x + dx + 1][y + dy + 1]++;
            }
        }
    }

    // In-place 2D prefix sums (row-major order).
    for (int r = 1; r < side; r++) {
        for (int c = 1; c < side; c++) {
            pref[r][c] += pref[r][c - 1] + pref[r - 1][c] - pref[r - 1][c - 1];
        }
    }

    // Number of points in the K x K window whose top-left corner is (r0, c0).
    auto window = [&](int r0, int c0) {
        return pref[r0 + k][c0 + k] - pref[r0 + k][c0] - pref[r0][c0 + k] + pref[r0][c0];
    };

    int best = 0;
    for (int r = 0; r < period; r++) {
        for (int c = 0; c < period; c++) {
            best = max(best, window(r, c) + window(r + k, c + k));
        }
    }
    cout << best << endl;
}
図1のコード
% Figure 1: the 2K-periodic checkerboard (here one K spans 2 grid cells).
\documentclass[dvipdfmx,margin=1cm]{standalone}
\usepackage{newtxtext}
\usepackage{newtxmath}
\usepackage{tikz}
\usetikzlibrary{calc}
\usetikzlibrary{patterns}
\begin{document}
% yscale=-1 makes the y axis grow downward.
\begin{tikzpicture}[yscale=-1,
  foo/.style={circle,inner sep=0.07cm},
  ]
  % Shade the black K x K squares of the checkerboard.
  \foreach \x/\y in {2/0, 6/0, 0/2, 4/2, 2/4, 6/4, 0/6, 4/6} {
    \draw[fill=black!10] (\x,\y) rectangle ++(2,2);
  }
  % Unit grid plus a coarser grid with step K.
  \draw[help lines] (-0.5,-0.5) grid (8.5,8.5);
  \draw[step=2] (-0.5,-0.5) grid (8.5,8.5);
  % One 2K x 2K period.
  \draw[ultra thick] (0,0) rectangle (4,4);
  % A cell outside the period maps into it by taking coordinates mod 2K.
  \node[fill=black,foo] (a) at ($(0,3) +(.5,.5)$) {};
  \node[fill=black,foo] (b) at ($(4,7) +(.5,.5)$) {};
  \draw[-latex,very thick] (b) to node [auto,xshift=-0.4cm] {$\bmod 2K$} (a);
  % A white request becomes a black request after shifting x by K, then mod 2K.
  \node[fill=white,draw=black,foo] (c) at ($(5,1) +(.5,.5)$) {};
  \node[fill=black,foo] (d) at ($(7,1) +(.5,.5)$) {};
  \node[fill=black,foo] (e) at ($(3,1) +(.5,.5)$) {};
  \draw[-latex,very thick] (c) to node [auto] {$+K$} (d);
  \draw[-latex,very thick] (d) to [bend right=30] node [auto,xshift=-0.1cm] {$\bmod 2K$} (e);
  % Side-length labels of the period.
  \draw (0,0) to [very thick,bend right=20] node [auto] {$2K$} (4,0);
  \draw (0,0) to [very thick,bend left=20] node [auto,swap] {$2K$} (0,4);
\end{tikzpicture}
\end{document}
図2のコード
% pgfmanual v3.0.1a
% Figure 2: the reduced points in one 2K x 2K period (left) are copied 2 x 2
% (right) so that a wrapped-around window becomes one contiguous rectangle.
% Line breaks restored: with everything on one line, the leading `%` comments
% out the whole document.
\documentclass[dvipdfmx,margin=1cm]{standalone}
\usepackage{newtxtext}
\usepackage{newtxmath}
\usepackage{tikz}
\usetikzlibrary{calc}
\usetikzlibrary{patterns}
\begin{document}
\begin{tikzpicture}[yscale=-1]
  % Left: one period with three sample points and a window that wraps around.
  \begin{scope}[shift={(0,2)},local bounding box=scope1]
    \draw [help lines] (-0.5,-0.5) grid (4.5,4.5);
    \draw [step=2] (-0.5,-0.5) grid (4.5,4.5);
    \draw [very thick] (0,0) rectangle (4,4);
    \foreach \dy/\dx in {0/0} {
      \foreach \x/\y in {1/0, 3/1, 2/3}{
        \draw [fill=black,circle] ($(\x,\y) + (.5,.5) + (\dx,\dy)$) circle[radius=0.1cm];
      }
    }
    % pattern=crosshatch dots -- See pgfmanual p.666
    \draw [ultra thick,pattern=crosshatch dots] (1,2) rectangle (3,4);
    \draw [ultra thick,pattern=crosshatch dots] (0,0) rectangle (1,2);
    \draw [ultra thick,pattern=crosshatch dots] (3,0) rectangle (4,2);
    \draw [-latex,line width=0.1cm] (0.5,1.5) -- (1,2);
  \end{scope}
  % Right: the period copied into a 2 x 2 tiling; the same window is contiguous.
  \begin{scope}[shift={(10,0)},local bounding box=scope2]
    \draw [fill=black!20] (0,0) rectangle (4,4);
    \draw [fill=red!20] (4,0) rectangle (8,4);
    \draw [fill=blue!20] (0,4) rectangle (4,8);
    \draw [fill=green!20] (4,4) rectangle (8,8);
    \draw [help lines] (-0.5,-0.5) grid (8.5,8.5);
    \draw [step=2] (-0.5,-0.5) grid (8.5,8.5);
    \draw [very thick] (0,0) rectangle (4,4);
    \foreach \dy/\dx in {0/0, 0/4, 4/0, 4/4} {
      \foreach \x/\y in {1/0, 3/1, 2/3}{
        \draw [fill=black,circle] ($(\x,\y) + (.5,.5) + (\dx,\dy)$) circle[radius=0.1cm];
      }
    }
    \draw [-latex,ultra thick] (2,0) to [bend right=30] node [auto] {Copy} (6,0);
    \draw [-latex,ultra thick] (0,2) to [bend left=30] node [auto,swap] {Copy} (0,6);
    \draw [-latex,ultra thick] (2,8) to [bend left=30] node [auto,swap] {Copy} (6,8);
    \draw [ultra thick,pattern=crosshatch dots] (1,2) rectangle (3,4);
    \draw [ultra thick,pattern=crosshatch dots] (3,4) rectangle (5,6);
    \draw [-latex,line width=0.1cm] (0.5,1.5) -- (1,2);
  \end{scope}
  \draw [-latex,ultra thick] ($(scope1.east)+(0.5,0)$) -- node [auto] {Transform} ($(scope2.west)-(0.5,0)$);
\end{tikzpicture}
\end{document}
CF #458 E. Palindromes in a Tree
(訂正: `sz[v] >= x / 2` ではなく `sz[v] > x / 2` が正しい)。計算量に影響はないです。(2018-01-22 21:25)
http://codeforces.com/contest/914/problem/E
重心分解による分割統治法で解く。重心を c としたとき、回文パスは c を通るものと通らないものに分けることができる。c を通らないものは部分問題で処理することにして、c を通るものを数える。各頂点 u について、u を端点とし c を通る回文パスの個数が数えられるので、その個数をパス u-c 上の各頂点に加算すればよい(実装では部分木ごとの和として一括で加算している)。
回文かどうかの判定はビットマスクを使うとやりやすい。各文字の出現回数の偶奇を 20 bit のマスクで持つと、パス上の文字を並べ替えて回文にできるのは、マスクの立っているビットが高々 1 個のときである。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <string>
#include <map>
using namespace std;

// CF 914E "Palindromes in a Tree" — centroid decomposition.
// Each vertex carries the bit 1 << (letter); a path can be permuted into a
// palindrome iff the XOR of its bits has at most one bit set.

const int N = 2e5;
vector<int> adj[N];
bool removed[N];             // vertex already used as a centroid
int subtree[N];              // subtree sizes inside the current component
int letter_bit[N];           // 1 << (s[i] - 'a')
long long result[N];         // palindromic paths through each vertex, minus the single-vertex one
long long mask_count[1 << 20];  // mask_count[m] = number of centroid..w paths with parity mask m

// Fills subtree[] for the component rooted at u.
void calc_size(int u, int parent) {
    subtree[u] = 1;
    for (int v : adj[u]) {
        if (v == parent || removed[v]) continue;
        calc_size(v, u);
        subtree[u] += subtree[v];
    }
}

// Walks toward the heavy side until no child subtree exceeds total / 2.
int find_centroid(int u, int parent, int total) {
    for (int v : adj[u]) {
        if (v == parent || removed[v]) continue;
        if (subtree[v] > total / 2) return find_centroid(v, u, total);
    }
    return u;
}

// Adds `delta` to mask_count for the prefix mask of every vertex in u's subtree.
void update_masks(int u, int parent, int mask, int delta) {
    mask ^= letter_bit[u];
    mask_count[mask] += delta;
    for (int v : adj[u]) {
        if (v != parent && !removed[v]) update_masks(v, u, mask, delta);
    }
}

// For every w in u's subtree, counts recorded partners that make a palindromic
// path, and credits each vertex with the count of such paths starting at or
// below it. Returns the total for u's subtree.
long long count_paths(int u, int parent, int mask) {
    mask ^= letter_bit[u];
    long long total = mask_count[mask];
    for (int bit = 0; bit < 20; bit++) total += mask_count[mask ^ (1 << bit)];
    for (int v : adj[u]) {
        if (v != parent && !removed[v]) total += count_paths(v, u, mask);
    }
    result[u] += total;
    return total;
}

void decompose(int start) {
    calc_size(start, -1);
    const int c = find_centroid(start, -1, subtree[start]);
    update_masks(c, -1, 0, +1);
    removed[c] = true;

    long long through_c = mask_count[0];
    for (int bit = 0; bit < 20; bit++) through_c += mask_count[1 << bit];
    for (int v : adj[c]) {
        if (removed[v]) continue;
        update_masks(v, c, letter_bit[c], -1);  // exclude v's own subtree
        through_c += count_paths(v, c, 0);
        update_masks(v, c, letter_bit[c], +1);  // restore it
    }
    // Every path through c is counted twice, and the single vertex c once;
    // integer division drops that extra one (main adds it back for every vertex).
    result[c] += through_c / 2;

    update_masks(c, -1, 0, -1);  // clear mask_count for the next component
    for (int v : adj[c]) {
        if (!removed[v]) decompose(v);
    }
}

int main() {
    int n;
    cin >> n;
    for (int e = 1; e < n; e++) {
        int u, v;
        scanf("%d %d", &u, &v);
        u--;
        v--;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }
    string s;
    cin >> s;
    for (int i = 0; i < n; i++) letter_bit[i] = 1 << (s[i] - 'a');
    decompose(0);
    for (int i = 0; i < n; i++) printf("%lld ", result[i] + 1);  // +1: the single-vertex path
}
図のコード (TikZ)。
% pgfmanual version 3.0.1a
% Figure: a u-v path through the centroid c (left) is split into u-c and c-v (right).
% Line breaks restored: with everything on one line, the leading `%` comments
% out the whole document.
\documentclass[dvipdfmx,margin=1cm]{standalone}
\usepackage{newtxtext}
\usepackage{newtxmath}
\usepackage{tikz}
\usetikzlibrary{scopes}
\usetikzlibrary{calc} % pgfmanual.pdf p.142
\usetikzlibrary{shadows} % pgfmanual.pdf p.689
\usetikzlibrary{shapes.geometric}
\begin{document}
\begin{tikzpicture}[
  foo/.style={fill=black,draw=none,text=white,drop shadow,shape=circle},
  bar/.style={fill=white,draw=black,text=white,drop shadow,shape=isosceles triangle},
  baz/.style={fill=red,draw=none,text=white,drop shadow,shape=circle},
  ]
  % Left: the whole path u -> v passing through c.
  \begin{scope}[local bounding box=scope1]
    \node [foo] (a) at (0,0) {$c$};
    % `shape border uses incircle` allows rotations that are not multiples of 90 degrees.
    \node [bar,shape border uses incircle,shape border rotate=180,minimum size=2cm] (A) at (0:3) {};
    \node [bar,shape border uses incircle,shape border rotate=300,minimum size=2cm] (B) at (120:3) {};
    \node [bar,shape border uses incircle,shape border rotate=420,minimum size=2cm] (C) at (240:3) {};
    % See pgfmanual.pdf p.703
    \draw (a) -- (A.apex);
    \draw (a) -- (B.apex);
    \draw (a) -- (C.apex);
    \node [baz] (u) at (B.center) {$u$};
    \node [baz] (v) at (C.center) {$v$};
    \draw [-latex,draw=black!80,line width=0.1cm] (u) -- (B.apex) .. controls (120:0.5) and (240:0.5) .. (C.apex) -- (v);
  \end{scope}
  % Right: the same path decomposed into u -> c and c -> v.
  \begin{scope}[shift={($(scope1.east)-(scope1.west)+(3.5cm,0)$)},local bounding box=scope2]
    \node [foo] (a) at (0,0) {$c$};
    \node [bar,shape border uses incircle,shape border rotate=180,minimum size=2cm] (A) at (0:3) {};
    \node [bar,shape border uses incircle,shape border rotate=300,minimum size=2cm] (B) at (120:3) {};
    \node [bar,shape border uses incircle,shape border rotate=420,minimum size=2cm] (C) at (240:3) {};
    \draw (a) -- (A.apex);
    \draw (a) -- (B.apex);
    \draw (a) -- (C.apex);
    \node [baz] (u) at (B.center) {$u$};
    \node [baz] (v) at (C.center) {$v$};
    \draw [-latex,draw=black!80,line width=0.1cm] (u) -- (a);
    \draw [-latex,draw=black!80,line width=0.1cm] (a) -- (v);
  \end{scope}
  % \draw (scope1.south west) rectangle (scope1.north east);
  % \draw (scope2.south west) rectangle (scope2.north east);
  % \node [draw,circle] at (scope1.east) {};
  % \node [draw,circle] at (scope2.west) {};
  \draw [-latex,line width=0.05cm] ($(scope1.east) + (0.5cm,0)$) -- node [auto] {Decompose} ($(scope2.west) + (-0.5cm,0)$);
\end{tikzpicture}
\end{document}
感想
shape border uses incircle が結構重要っぽくて、これがないと90°単位でしか回転できないらしい。
CF #458 G. Sum the Fibonacci
AND, OR, XOR 畳み込みを使う。ANDは分割統治、ORは高速ゼータ変換と高速メビウス変換、XORは高速アダマール変換を使うことで処理できる。
n=2^17 としたとき、ANDがO(n log n)、ORがO(n log^2 n)、XORがO(n log n)になっている。ただし、ここでの OR は $s_a \,\&\, s_b = 0$ という条件付き(部分集合畳み込み)なので、popcount ごとに分けてゼータ変換しており、その分 log が 1 つ多い。OR も O(n log n) で行けるのだろうか?
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <string>
#include <functional>
#include <ctime>
using namespace std;

// CF 914G "Sum the Fibonacci" — AND / subset-OR / XOR convolutions mod 1e9+7.

const int mod = 1e9 + 7;

// Residue modulo `mod`; the stored value always lies in [0, mod).
struct Modint {
    int n;
    Modint(int n = 0) : n(n) {}
};

Modint operator+(Modint a, Modint b) {
    const int s = a.n + b.n;
    return Modint(s >= mod ? s - mod : s);
}
Modint operator-(Modint a, Modint b) {
    const int d = a.n - b.n;
    return Modint(d < 0 ? d + mod : d);
}
Modint operator*(Modint a, Modint b) {
    return Modint(static_cast<int>(static_cast<long long>(a.n) * b.n % mod));
}
Modint &operator+=(Modint &a, Modint b) { a = a + b; return a; }
Modint &operator-=(Modint &a, Modint b) { a = a - b; return a; }
Modint &operator*=(Modint &a, Modint b) { a = a * b; return a; }

// Binary exponentiation: a^b for b >= 0.
Modint modpow(Modint a, long long b) {
    Modint res = 1;
    for (; b > 0; b >>= 1, a *= a) {
        if (b & 1) res *= a;
    }
    return res;
}

// Subset-sum (zeta) transform: f[S] <- sum of f[T] over T subset of S.
void fast_zeta_transform(vector<Modint> &f) {
    const int len = f.size();
    for (int bit = 1; bit < len; bit <<= 1) {
        for (int mask = 0; mask < len; mask++) {
            if (mask & bit) f[mask] += f[mask ^ bit];
        }
    }
}

// Inverse of fast_zeta_transform.
void fast_mobius_transform(vector<Modint> &f) {
    const int len = f.size();
    for (int bit = 1; bit < len; bit <<= 1) {
        for (int mask = 0; mask < len; mask++) {
            if (mask & bit) f[mask] -= f[mask ^ bit];
        }
    }
}

// Unnormalized Walsh-Hadamard transform of a[l, r); r - l must be a power of two.
void hadamard_transform(std::vector<Modint> &a, int l, int r) {
    for (int half = 1; half < r - l; half <<= 1) {
        for (int block = l; block < r; block += 2 * half) {
            for (int i = block; i < block + half; i++) {
                const Modint x = a[i], y = a[i + half];
                a[i] = x + y;
                a[i + half] = x - y;
            }
        }
    }
}

// Self XOR-convolution: res[k] = sum over i ^ j == k of a[i] * a[j].
vector<Modint> xor_convolution(vector<Modint> a) {
    const int len = a.size();
    hadamard_transform(a, 0, len);
    vector<Modint> res(len);
    for (int i = 0; i < len; i++) res[i] = a[i] * a[i];
    hadamard_transform(res, 0, len);
    const Modint inv = modpow(Modint(len), mod - 2);  // 1 / len
    for (Modint &v : res) v *= inv;
    return res;
}

int bitcnt[1 << 17];

// Self subset convolution: res[k] = sum over i | j == k, i & j == 0 of a[i] * a[j].
// Elements are split by popcount (rank); disjointness holds exactly when ranks add up.
vector<Modint> or_convolution(vector<Modint> a) {
    const int len = a.size();
    const int ranks = 18;
    vector<Modint> res(len);
    vector<vector<Modint>> ranked(ranks, vector<Modint>(len));
    for (int m = 0; m < len; m++) ranked[bitcnt[m]][m] += a[m];
    for (auto &row : ranked) fast_zeta_transform(row);
    for (int r = 0; r < ranks; r++) {
        vector<Modint> prod(len);
        for (int j = 0; j <= r; j++) {
            const vector<Modint> &lhs = ranked[j];
            const vector<Modint> &rhs = ranked[r - j];
            for (int m = 0; m < len; m++) prod[m] += lhs[m] * rhs[m];
        }
        fast_mobius_transform(prod);
        for (int m = 0; m < len; m++) {
            if (bitcnt[m] == r) res[m] += prod[m];
        }
    }
    return res;
}

// AND-convolution: res[k] = sum over i & j == k of a[i] * b[j].
// Divide and conquer on the top bit: fold the upper half into the lower half
// (superset sums), recurse on both halves, then unfold the lower half.
vector<Modint> and_convolution(vector<Modint> a, vector<Modint> b) {
    function<void(int, int)> solve = [&](int lo, int hi) {
        const int len = hi - lo;
        if (len == 1) {
            a[lo] *= b[lo];
            return;
        }
        const int mid = (lo + hi) / 2;
        const int half = len / 2;
        for (int i = 0; i < half; i++) {
            a[lo + i] += a[mid + i];
            b[lo + i] += b[mid + i];
        }
        solve(lo, mid);
        solve(mid, hi);
        for (int i = 0; i < half; i++) a[lo + i] -= a[mid + i];
    };
    solve(0, a.size());
    return a;
}

int main() {
    const int len = 1 << 17;
    int n;
    cin >> n;
    for (int i = 1; i < len; i++) bitcnt[i] = bitcnt[i & (i - 1)] + 1;

    vector<Modint> f(len);
    for (int i = 0; i < n; i++) {
        int s;
        scanf("%d", &s);
        f[s] += 1;
    }

    vector<Modint> fib(len);
    fib[0] = 0;
    fib[1] = 1;
    for (int i = 2; i < len; i++) fib[i] = fib[i - 1] + fib[i - 2];

    vector<Modint> a = or_convolution(f);
    vector<Modint> b = xor_convolution(f);
    for (int i = 0; i < len; i++) {
        a[i] = fib[i] * a[i];
        b[i] = fib[i] * b[i];
        f[i] = fib[i] * f[i];
    }
    a = and_convolution(a, b);
    a = and_convolution(a, f);

    Modint ans;
    for (int i = 0; i < 17; i++) ans += a[1 << i];
    cout << ans.n << endl;
}
CF #457 E. Jamie and Tree
http://codeforces.com/problemset/problem/916/E
LC-tree は根を変更するクエリと lca を求めるクエリを捌ける。ET-tree は辺の追加・削除と、連結成分全体に x を一様加算するクエリを捌ける。部分木に対する操作は、その頂点と (現在の根から見た) 親との間の辺を一旦切り、連結成分全体に対して加算・和の取得を行ってから辺を戻せばよい。これらを組み合わせることで容易に解くことができる。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <map>
using namespace std;

// Euler-tour tree (ET-tree) built on splay trees.
// Every vertex u owns one "active" node super[u] carrying its value, and every
// directed edge (u, v) owns one inactive node edges[{u, v}]. The in-order of a
// splay tree is the Euler tour of one connected component, so a lazy add on the
// splay root adds to every vertex of that component, and the root's `sum` is the
// sum over the component.
struct ETTree {
  struct node {
    node *l = nullptr;
    node *r = nullptr;
    node *p = nullptr;
    bool active;         // true for vertex nodes, false for edge nodes
    int size;            // number of active nodes in this subtree
    long long val = 0;   // this vertex's value (meaningful for active nodes only)
    long long lazy = 0;  // add still owed to the children (already applied to this node)
    long long sum = 0;   // sum of val over the active nodes in this subtree
    node(bool active_, int v = 0) : active(active_), size(active_), val(v), sum(v) {}
  };
  map<pair<int, int>, node *> edges;
  vector<node *> super;
  const int n;

  // Builds the Euler tours of tree g with initial vertex values a.
  ETTree(const vector<vector<int>> &g, vector<int> a) : n(g.size()) {
    super.resize(n);
    for (int i = 0; i < n; i++) {
      for (int j : g[i]) {
        edges[{i, j}] = new node(false);
      }
      super[i] = new node(true, a[i]);
    }
    for (int i = 0; i < n; i++) {
      for (int j : g[i]) {
        if (i < j) link(i, j);
      }
    }
  }

  // Joins the components of u and v with edge (u, v):
  // [u ...] + [(u,v)] + [v ...] + [(v,u)].
  void link(int u, int v) {
    node *uv = edges[{u, v}];
    node *vu = edges[{v, u}];
    node *x = super[u];
    node *y = super[v];
    rotate(x);
    rotate(y);
    join(x, uv);
    join(uv, y);
    join(y, vu);
  }

  // Removes edge (u, v). After rotate(uv) the tour is
  // [(u,v) ... v-side ... (v,u) ... u-side ...]; the splits isolate both edge
  // nodes and leave the two sides as separate trees.
  void cut(int u, int v) {
    node *uv = edges[{u, v}];
    node *vu = edges[{v, u}];
    rotate(uv);
    split_right(uv);
    split_left(vu);
    split_right(vu);
  }

  // Adds w to every vertex in u's component.
  void add(int u, long long w) { put_lazy(splay(super[u]), w); }

  // Number of vertices in u's component.
  int find_size(int u) { return splay(super[u])->size; }

  // Sum of the vertex values in u's component.
  long long access(int u) { return splay(super[u])->sum; }

  // Single rotation of x above its parent; refreshes the parent's aggregates.
  void rot(node *x) {
    node *y = x->p, *z = y->p;
    if (z) {
      if (y == z->l) z->l = x;
      if (y == z->r) z->r = x;
    }
    x->p = z;
    y->p = x;
    if (x == y->l) {
      y->l = x->r;
      x->r = y;
      if (y->l) y->l->p = y;
    } else {
      y->r = x->l;
      x->l = y;
      if (y->r) y->r->p = y;
    }
    pull(y);
  }

  // Pushes pending adds on the whole root-to-x path, top-down.
  void push_above(node *x) {
    if (x->p) push_above(x->p);
    push(x);
  }

  // Moves x to the root of its splay tree and returns it.
  node *splay(node *x) {
    push_above(x);
    while (x->p) {
      node *y = x->p, *z = y->p;
      if (z) rot((x == y->l) == (y == z->l) ? y : x);
      rot(x);
    }
    pull(x);
    return x;
  }

  // Recomputes size and sum from the children. Requires x->lazy == 0.
  void pull(node *x) {
    x->size = x->active;
    if (x->l) x->size += x->l->size;
    if (x->r) x->size += x->r->size;
    x->sum = 0;
    if (x->active) {
      x->sum += x->val;
    }
    if (x->l) x->sum += x->l->sum;
    if (x->r) x->sum += x->r->sum;
  }

  // Adds v to every vertex below and including x (lazily for the children).
  void put_lazy(node *x, long long v) {
    x->lazy += v;
    if (x->active) {
      x->val += v;
    }
    x->sum += v * x->size;
  }

  void push(node *x) {
    if (x->l) put_lazy(x->l, x->lazy);
    if (x->r) put_lazy(x->r, x->lazy);
    x->lazy = 0;
  }

  // [_ _ x _ _] + [_ _ y _ _] -> [_ _ x _ _ _ _ y _ _]
  void join(node *x, node *y) {
    if (!x || !y) return;
    splay(x);
    while (x->r) x = x->r;
    splay(x);
    splay(y);
    x->r = y;
    y->p = x;
    pull(x);
  }

  // [_ _ x _ _] -> [_ _] [x _ _]; returns the detached left part.
  node *split_left(node *x) {
    splay(x);
    node *y = x->l;
    if (y) {
      y->p = x->l = nullptr;
      pull(x);
    }
    return y;
  }

  // [_ _ x _ _] -> [_ _ x] [_ _]; returns the detached right part.
  node *split_right(node *x) {
    splay(x);
    node *y = x->r;
    if (y) {
      y->p = x->r = nullptr;
      pull(x);
    }
    return y;
  }

  // [0 1 2 x 3 4 5] -> [x 3 4 5 0 1 2]
  void rotate(node *x) { join(x, split_left(x)); }
};

// ---- Link-cut tree (no aggregates) used for rerooting, lca and parent queries ----
// `rev` is a lazy reversal flag: a node's own children are already swapped,
// the flag means its children still have to be reversed.
struct node {
  node *l = nullptr;
  node *r = nullptr;
  node *p = nullptr;
  bool rev = false;
};

// True when x is the root of its splay tree (its parent pointer, if any, is a path-parent).
bool is_root(node *x) { return !x->p || (x != x->p->l && x != x->p->r); }

void rot(node *x) {
  node *y = x->p, *z = y->p;
  if (z) {
    if (y == z->l) z->l = x;
    if (y == z->r) z->r = x;
  }
  x->p = z;
  y->p = x;
  if (x == y->l) {
    y->l = x->r;
    x->r = y;
    if (y->l) y->l->p = y;
  } else {
    y->r = x->l;
    x->l = y;
    if (y->r) y->r->p = y;
  }
}

void reverse(node *x) {
  swap(x->l, x->r);
  x->rev ^= true;
}

void push(node *x) {
  if (x->rev) {
    if (x->l) reverse(x->l);
    if (x->r) reverse(x->r);
  }
  x->rev = false;
}

void push_above(node *x) {
  if (x->p) push_above(x->p);
  push(x);
}

// Splay within x's splay tree; callers push the path beforehand.
void splay(node *x) {
  while (!is_root(x)) {
    node *y = x->p;
    if (!is_root(y)) {
      node *z = y->p;
      rot((x == y->l) == (y == z->l) ? y : x);
    }
    rot(x);
  }
}

// Access: makes root..x the preferred path and splays x to its top.
void splayLC(node *x) {
  node *tmp = x;
  push_above(x);
  for (node *r = nullptr; x; x = x->p) {
    splay(x);
    x->r = r;
    r = x;
  }
  splay(tmp);
}

void reroot(node *x) {
  splayLC(x);
  reverse(x);
}

void link(node *c, node *p) {
  reroot(c);
  splayLC(p);
  p->r = c;
  c->p = p;
}

void cut(node *x) {
  splayLC(x);
  x->l->p = nullptr;
  x->l = nullptr;
}

// LCA of x and y under the current root, or nullptr if they are disconnected.
node *lca(node *x, node *y) {
  splayLC(y);
  splayLC(x);
  bool same = false;
  node *l = y;
  while (y != nullptr) {
    if (is_root(y) && y->p) {
      l = y->p;  // last path-parent hop before reaching x's path
    }
    if (y == x->r) return x;
    if (x == y) same = true;
    y = y->p;
  }
  if (!same) return nullptr;
  return l;
}

// Parent of x under the current root (its predecessor on the root..x path),
// or nullptr if x is the root.
node *get_parent(node *x) {
  splayLC(x);
  if (!x->l) {
    return nullptr;
  }
  x = x->l;
  // Fix: apply x's pending reversal before descending. Otherwise x's children
  // keep a stale orientation and the walk below can follow the wrong side.
  push(x);
  while (x->r) {
    x = x->r;
    push(x);
  }
  splayLC(x);
  return x;
}

int input() {
  int n;
  scanf("%d", &n);
  return n;
}

int main() {
  int n, q;
  cin >> n;
  cin >> q;
  vector<node *> lc(n);
  map<node *, int> mp;  // LC node -> vertex id (nullptr -> -1)
  mp[nullptr] = -1;
  vector<vector<int>> g(n);
  for (int i = 0; i < n; i++) {
    lc[i] = new node();
    mp[lc[i]] = i;
  }
  vector<int> a(n);
  for (int i = 0; i < n; i++) {
    a[i] = input();
  }
  for (int i = 0; i < n - 1; i++) {
    int u = input() - 1;
    int v = input() - 1;
    g[u].push_back(v);
    g[v].push_back(u);
    link(lc[u], lc[v]);
  }
  reroot(lc[0]);
  ETTree et(g, a);
  while (q--) {
    int type = input();
    if (type == 1) {
      int v = input() - 1;
      reroot(lc[v]);
    } else if (type == 2) {
      // Add x to the subtree (under the current root) of lca(u, v): detach it
      // from its parent, add to the whole detached component, reattach.
      int u = input() - 1;
      int v = input() - 1;
      int x = input();
      int l = mp[lca(lc[u], lc[v])];
      int p = mp[get_parent(lc[l])];
      if (p != -1) et.cut(l, p);
      et.add(l, x);
      if (p != -1) et.link(l, p);
    } else {
      // Sum over the subtree of v, computed the same way.
      int v = input() - 1;
      int p = mp[get_parent(lc[v])];
      if (p != -1) et.cut(v, p);
      long long ans = et.access(v);
      if (p != -1) et.link(v, p);
      printf("%lld\n", ans);
    }
  }
}
恐らく普通の解 (HLD + オイラーツアー + BIT)。
#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <functional>
using namespace std;

// Reads one integer from stdin with scanf.
int input() {
  int n;
  scanf("%d", &n);
  return n;
}

// Heavy-light decomposition of tree g rooted at vertex 0.
//   parent[u]: parent of u (-1 for vertex 0)
//   head[u]  : first vertex of the chain containing u
//   vid[u]   : DFS visit index; every chain occupies consecutive indices
//   inv[i]   : the vertex whose vid is i
struct HLD {
  using Tree = std::vector<std::vector<int>>;
  std::vector<int> parent, head, vid, inv;
  HLD(const Tree &g) : parent(g.size()), head(g.size()), vid(g.size()), inv(g.size()) {
    int k = 0;
    std::vector<int> size(g.size(), 1);
    // Subtree sizes.
    std::function<void(int, int)> dfs = [&](int u, int p) {
      for (int v : g[u]) if (v != p) dfs(v, u), size[u] += size[v];
    };
    // A child whose subtree is more than half of u's subtree stays on u's chain
    // and is visited first; every other child starts a new chain.
    std::function<void(int, int, int)> dfs2 = [&](int u, int p, int h) {
      parent[u] = p;
      head[u] = h;
      vid[u] = k++;
      inv[vid[u]] = u;
      for (int v : g[u]) if (v != p && size[u] < size[v] * 2) dfs2(v, u, h);
      for (int v : g[u]) if (v != p && size[u] >= size[v] * 2) dfs2(v, u, v);
    };
    dfs(0, -1);
    dfs2(0, -1, 0);
  }
  // Orders a, b, c by decreasing vid, and while a and b are on different chains
  // lifts a to the parent of its chain head; then returns b.
  // main() calls it as lca(u, v, root) to get the LCA of u and v when the tree is
  // rooted at `root` (presumably the "median" of the three vertices —
  // NOTE(review): confirm).
  int lca(int a, int b, int c) {
    while (true) {
      if (vid[a] < vid[b]) swap(a, b);
      if (vid[a] < vid[c]) swap(a, c);
      if (vid[b] < vid[c]) swap(b, c);
      if (head[a] == head[b]) return b;
      a = parent[head[a]];
    }
  }
  // Returns -1 if u == r. While r has a larger vid than u, climbs r chain by chain:
  // if r is on u's chain, returns the next vertex after u on that chain; if r's
  // chain hangs directly below u, returns that chain's head. Otherwise returns
  // parent[u]. main() uses it as "the parent of u when the tree is rooted at r".
  int find_parent(int u, int r) {
    if (u == r) return -1;
    while (vid[u] < vid[r]) {
      if (head[u] == head[r]) return inv[vid[u] + 1];
      if (parent[head[r]] == u) return head[r];
      r = parent[head[r]];
    }
    return parent[u];
  }
};

// Preorder Euler tour rooted at 0: the subtree of u occupies the half-open
// index range [l[u], r[u]).
struct EulerTour {
  std::vector<int> l;
  std::vector<int> r;
  EulerTour(const std::vector<std::vector<int>> &g) {
    const int n = g.size();
    l.resize(n);
    r.resize(n);
    int k = 0;
    std::function<void(int, int)> dfs = [&](int u, int p) {
      l[u] = k++;
      for (int v : g[u]) if (v != p) dfs(v, u);
      r[u] = k;
    };
    dfs(0, -1);
  }
};

// Fenwick tree with range add and range sum over half-open ranges [l, r),
// using two arrays (s0: constant terms, s1: coefficients of the index).
template<class T> struct BIT {
  vector<T> s0;
  vector<T> s1;
  BIT(int n) : s0(n + 2), s1(n + 2) {}
  // Adds v to every position >= k.
  void add0(int k, T v) {
    for (int i = k + 1; i < s0.size(); i += i & -i) {
      s0[i] -= v * k;
      s1[i] += v;
    }
  }
  // Adds v to every position in [l, r).
  void add(int l, int r, T v) {
    add0(l, v);
    add0(r, -v);
  }
  // Sum of positions [0, k).
  T sum0(int k) {
    T t0 = 0;
    T t1 = 0;
    for (int i = k + 1; i > 0; i -= i & -i) {
      t0 += s0[i];
      t1 += s1[i];
    }
    return t0 + t1 * k;
  }
  // Sum of positions [l, r).
  T sum(int l, int r) { return sum0(r) - sum0(l); }
  // Sum of positions [0, n).
  T sumall() { return sum0(s0.size() - 2); }
};

int main() {
  int n, q;
  cin >> n;
  cin >> q;
  vector<vector<int>> g(n);
  vector<int> a(n);
  for (int i = 0; i < n; i++) {
    a[i] = input();
  }
  for (int i = 0; i < n - 1; i++) {
    int u = input() - 1;
    int v = input() - 1;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  EulerTour et(g);
  HLD hld(g);
  int root = 0;
  // Each vertex's value is `all` (added to every vertex) plus its BIT entry
  // at Euler-tour position et.l[v].
  long long all = 0;
  BIT<long long> bit(n);
  for (int i = 0; i < n; i++) {
    bit.add(et.l[i], et.l[i] + 1, a[i]);
  }
  while (q--) {
    int type = input();
    if (type == 1) {
      int v = input() - 1;
      root = v;
    } else if (type == 2) {
      int u = input() - 1;
      int v = input() - 1;
      int x = input();
      int l = hld.lca(u, v, root);
      int p = hld.find_parent(l, root);
      if (p == -1) {
        // l is the current root: its subtree is the whole tree.
        all += x;
      } else if (p == hld.parent[l]) {
        // Same orientation as the root-0 tree: l's original subtree.
        bit.add(et.l[l], et.r[l], x);
      } else {
        // p is an original child of l: everything except p's original subtree.
        bit.add(et.l[p], et.r[p], -x);
        all += x;
      }
    } else {
      int v = input() - 1;
      int p = hld.find_parent(v, root);
      long long ans = 0;
      if (p == -1) {
        ans = all * n + bit.sumall();
      } else if (p == hld.parent[v]) {
        ans = all * (et.r[v] - et.l[v]) + bit.sum(et.l[v], et.r[v]);
      } else {
        // Whole tree minus p's original subtree.
        ans += all * n + bit.sumall();
        ans -= all * (et.r[p] - et.l[p]) + bit.sum(et.l[p], et.r[p]);
      }
      printf("%lld\n", ans);
    }
  }
}
感想
ET-tree が遅い。
CF #457 D. Jamie and To-do List
http://codeforces.com/problemset/problem/916/D
誤読さえしなければ解法は自明。永続 multiset と永続配列があれば良い。永続 multiset は使い回せそうな形で再実装した。内部的には LLRB 木 (left-leaning red-black tree) を用いている。
ポインタ型が 64bit であることが厄介で、node にポインタを持たせると MLE する (そのため MultiSet ではノードを int の添字で管理している)。(CF は 32bit 環境らしい。なんで MLE したんだ…?)
#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <string>
using namespace std;

// Node of the persistent multiset, stored by index in a shared pool.
// Index 0 is the null sentinel (sz == 0, c == false).
template<class T> struct MSnode {
  int l = 0;    // left child index (0 = none)
  int r = 0;    // right child index (0 = none)
  int sz = 0;   // total multiplicity stored in this subtree
  int cnt = 0;  // multiplicity of `key` itself
  T key;
  bool c;       // link colour: true = red
};

// Persistent multiset: a left-leaning red-black tree with path copying.
// Every update clones the nodes it touches, so copies of a MultiSet (which hold
// only a root index) keep seeing their own version. All nodes live in the
// static pool `xs` and are addressed by int index rather than by pointer.
template<class T> class MultiSet {
  using node_t = int;
  static int curr;              // next free index in xs
  static vector<MSnode<T>> xs;  // node pool shared by all versions
  node_t root = 0;
public:
  // The root returned by insert() is always a fresh clone, so recolouring it
  // in place does not affect older versions.
  void insert(T key) {
    root = insert(root, key);
    xs[root].c = false;
  }
  // Assumes key is present (see erase(node_t, T)). On an empty tree this writes
  // to the sentinel, whose colour is already false.
  void erase(T key) {
    root = erase(root, key);
    xs[root].c = false;
  }
  // Number of stored elements (with multiplicity) strictly less than key.
  int countLT(T key) { return countLT(root, key); }
private:
  // Recomputes sz from the children and cnt.
  node_t fix(node_t x) {
    xs[x].sz = xs[xs[x].l].sz + xs[xs[x].r].sz + xs[x].cnt;
    return x;
  }
  // Allocates a red node with sz = cnt = 1, doubling the pool when full.
  // Growth may move xs, which is why only indices are kept, never references.
  node_t create() {
    if (xs.size() <= curr) {
      xs.resize(xs.size() * 2);
    }
    xs[curr].sz = 1;
    xs[curr].cnt = 1;
    xs[curr].c = true;
    return curr++;
  }
  node_t create(T key) {
    int x = create();
    xs[x].key = key;
    return x;
  }
  // Copies node x into a fresh slot.
  node_t clone(node_t x) {
    node_t y = create();
    xs[y] = xs[x];
    return y;
  }
  // Clone of x with new children and colour.
  node_t create(node_t x, node_t l, node_t r, bool c) {
    x = clone(x);
    xs[x].l = l;
    xs[x].r = r;
    xs[x].c = c;
    return fix(x);
  }
  // Clone of x with new children, colour kept.
  node_t create(node_t x, node_t l, node_t r) {
    x = clone(x);
    xs[x].l = l;
    xs[x].r = r;
    return fix(x);
  }
  // Clone of x with a new colour.
  node_t create(node_t x, bool c) {
    x = clone(x);
    xs[x].c = c;
    return x;
  }
  node_t insert(node_t x, T key) {
    if (!x) return create(key);
    if (key == xs[x].key) {
      x = clone(x);
      xs[x].cnt++;
      x = fix(x);
    } else if (key < xs[x].key) {
      x = create(x, insert(xs[x].l, key), xs[x].r);
    } else {
      x = create(x, xs[x].l, insert(xs[x].r, key));
    }
    // Red right link: rotate left (the new top inherits x's colour).
    if (xs[xs[x].r].c) {
      x = create(xs[x].r, create(x, xs[x].l, xs[xs[x].r].l, true), xs[xs[x].r].r, xs[x].c);
    }
    // Two red left links in a row: rotate right and flip colours
    // (both children black, new top red).
    if (xs[xs[x].l].c && xs[xs[xs[x].l].l].c) {
      x = create(xs[x].l, create(xs[xs[x].l].l, false), create(x, xs[xs[x].l].r, xs[x].r, false), true);
    }
    return x;
  }
  // Lazy deletion: only decrements cnt, leaving a node with cnt == 0 in place, so
  // the tree is never restructured on erase. Assumes key is present — cnt is
  // decremented unconditionally (main only erases values it inserted).
  node_t erase(node_t x, T key) {
    if (!x) return 0;
    if (key == xs[x].key) {
      x = clone(x);
      xs[x].cnt--;
      x = fix(x);
    } else if (key < xs[x].key) {
      x = create(x, erase(xs[x].l, key), xs[x].r);
    } else {
      x = create(x, xs[x].l, erase(xs[x].r, key));
    }
    return x;
  }
  int countLT(node_t x, T key) {
    if (!x) return 0;
    int res = 0;
    if (xs[x].key < key) {
      res += xs[xs[x].l].sz;
      res += xs[x].cnt;
      res += countLT(xs[x].r, key);
    } else {
      res += countLT(xs[x].l, key);
    }
    return res;
  }
};
template<class T> int MultiSet<T>::curr = 1;
template<class T> vector<MSnode<T>> MultiSet<T>::xs(1);

// Persistent array: a 16-ary trie whose depth is h / 4 (h = number of index
// bits, a multiple of 4). Missing paths read as T(). Copies share nodes.
template<class T> class Array {
public:
  // NOTE(review): h is left uninitialized here; such objects are only used as
  // placeholders that main overwrites before reading.
  Array() {}
  // Chooses h so that 16^(h / 4) >= n.
  Array(int n) {
    h = 0;
    for (int i = 1; i < n; i *= 16) h += 4;
  }
  // Path-copies root..leaf k and returns a pointer to the new leaf's value,
  // which the caller is expected to write (a newly created leaf's value is not
  // initialized by node()).
  T *mut_get(int k) {
    auto p = mutable_get(k, root, 0, h);
    root = p.first;
    return &p.second->value;
  }
  T get(int k) { return immutable_get(k, root, 0, h); }
private:
  struct node {
    node *ch[16] = {};
    T value;
    node() {}
    node(T value) : value(value) {}
  };
  int h;
  node *root = nullptr;
  // a: index, l: first index covered by x, d: remaining bit depth.
  T immutable_get(int a, node *x, int l, int d) {
    if (!x) return T();
    if (d == 0) return x->value;
    int id = a - l >> d - 4;  // (a - l) >> (d - 4): which of the 16 children
    return immutable_get(a, x->ch[id], l + (id << d - 4), d - 4);
  }
  // Returns {new copy of x, new leaf}.
  pair<node *, node *> mutable_get(int a, node *x, int l, int d) {
    x = x ? new node(*x) : new node();
    if (d == 0) return { x, x };
    int id = a - l >> d - 4;
    auto p = mutable_get(a, x->ch[id], l + (id << d - 4), d - 4);
    x->ch[id] = p.first;
    return { x, p.second };
  }
};

int main() {
  int q;
  cin >> q;
  // his[i] = state after day i: (multiset of priorities, id -> priority).
  // Priority 0 means "not in the to-do list".
  vector<pair<MultiSet<int>, Array<int>>> his(q + 1);
  MultiSet<int> st;
  Array<int> ar(q);
  his[0] = make_pair(st, ar);
  // Assignment name -> dense id (this local shadows std::stoi).
  map<string, int> stoi;
  auto get_id = [&](string key) {
    if (!stoi.count(key)) {
      int t = stoi.size();
      stoi[key] = t;
    }
    return stoi[key];
  };
  for (int ii = 1; ii <= q; ii++) {
    char type[30];
    char key[30];
    scanf("%s", type);
    if (type[0] == 's') {
      // set: replace the old priority (if any) with v.
      int v;
      scanf("%s %d", key, &v);
      int k = get_id(key);
      int pv = ar.get(k);
      if (pv != 0) {
        st.erase(pv);
      }
      *ar.mut_get(k) = v;
      st.insert(v);
    } else if (type[0] == 'r') {
      // remove: drop the assignment if present.
      scanf("%s", key);
      int k = get_id(key);
      int v = ar.get(k);
      if (v != 0) {
        st.erase(v);
        *ar.mut_get(k) = 0;
      }
    } else if (type[0] == 'q') {
      // query: number of assignments with a smaller priority, or -1 if absent.
      scanf("%s", key);
      int k = get_id(key);
      int v = ar.get(k);
      if (v != 0) {
        printf("%d\n", st.countLT(v));
      } else {
        printf("-1\n");
      }
      fflush(stdout);
    } else {
      // undo d: restore the state from before the previous d days.
      int d;
      scanf("%d", &d);
      tie(st, ar) = his[ii - d - 1];
    }
    his[ii] = make_pair(st, ar);
  }
}