// CF #457 D. Jamie and To-do List
// http://codeforces.com/problemset/problem/916/D
// 誤読さえしなければ解法は自明。永続multiset と永続配列があれば良い。永続multisetは使いまわせそうな形で再実装した。内部的には LLRBtree を用いている。
// ポインタ型が64bitというのが厄介で、nodeにポインタ型を持たせるとMLEする。(CFは32bit環境らしい。なんでMLEしたんだ…?)
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using namespace std;

// One node of the persistent multiset. Children are stored as indices into
// the shared node pool (not pointers); index 0 is the empty-tree sentinel.
template<class T>
struct MSnode {
    int l = 0;       // pool index of the left child (0 = none)
    int r = 0;       // pool index of the right child (0 = none)
    int sz = 0;      // number of elements (with multiplicity) in this subtree
    int cnt = 0;     // multiplicity of `key` at this node; may be 0 after erase
    T key{};
    bool c = false;  // colour flag: true = red, false = black
};

// Persistent multiset backed by a left-leaning red-black style tree.
//
// Every instance is just a root index into a pool shared by all instances of
// MultiSet<T>, so copying a MultiSet is O(1) and yields an independent
// snapshot: updates clone the nodes on the search path instead of modifying
// nodes reachable from older roots. Erasing only decrements a node's
// multiplicity; nodes are never unlinked and pool slots are never reused.
template<class T>
class MultiSet {
    using node_t = int;
    static int curr;                // next unused pool slot
    static vector<MSnode<T>> xs;    // node pool shared by all versions
    node_t root = 0;

public:
    // Adds one occurrence of `key`.
    void insert(T key) {
        root = insert(root, key);
        xs[root].c = false;  // insert always returns a freshly created root; keep it black
    }

    // Removes one occurrence of `key`; does nothing if `key` is not present.
    void erase(T key) {
        root = erase(root, key);
        // Roots are black already (the sentinel and every root produced by
        // insert), and erase never recolours, so this only restates the invariant.
        xs[root].c = false;
    }

    // Returns how many stored elements (with multiplicity) are strictly less than `key`.
    int countLT(T key) const { return countLT(root, key); }

private:
    // Recomputes the subtree size of `x` from its children and its own count.
    node_t fix(node_t x) {
        xs[x].sz = xs[xs[x].l].sz + xs[xs[x].r].sz + xs[x].cnt;
        return x;
    }

    // Allocates a fresh red node holding a single element; grows the pool by doubling.
    node_t create() {
        if (static_cast<int>(xs.size()) <= curr) {
            xs.resize(xs.size() * 2);
        }
        MSnode<T> &fresh = xs[curr];
        fresh.sz = 1;
        fresh.cnt = 1;
        fresh.c = true;
        return curr++;
    }

    // Allocates a fresh single-element node for `key`.
    node_t create(T key) {
        const node_t x = create();
        xs[x].key = key;
        return x;
    }

    // Copies node `x` into a new slot. The slot is allocated first so that
    // both subscripts are evaluated after any pool reallocation.
    node_t clone(node_t x) {
        const node_t y = create();
        xs[y] = xs[x];
        return y;
    }

    // Copy of `x` with new children and colour, size recomputed.
    node_t create(node_t x, node_t l, node_t r, bool c) {
        x = clone(x);
        xs[x].l = l;
        xs[x].r = r;
        xs[x].c = c;
        return fix(x);
    }

    // Copy of `x` with new children (colour kept), size recomputed.
    node_t create(node_t x, node_t l, node_t r) {
        x = clone(x);
        xs[x].l = l;
        xs[x].r = r;
        return fix(x);
    }

    // Copy of `x` with a new colour only (size is unaffected).
    node_t create(node_t x, bool c) {
        x = clone(x);
        xs[x].c = c;
        return x;
    }

    // Restores the left-leaning shape at `x` after an insertion below it.
    // All indices are read into locals before any allocation so that no value
    // is read from the pool in the same expression that may reallocate it.
    node_t rebalance(node_t x) {
        if (xs[xs[x].r].c) {
            // Red right child: rotate left. The old top becomes the red left
            // child and the old right child takes over the top's colour.
            const node_t right = xs[x].r;
            const node_t right_l = xs[right].l;
            const node_t right_r = xs[right].r;
            const bool top_colour = xs[x].c;
            const node_t lowered = create(x, xs[x].l, right_l, true);
            x = create(right, lowered, right_r, top_colour);
        }
        const node_t left = xs[x].l;
        if (xs[left].c && xs[xs[left].l].c) {
            // Two reds in a row on the left: rotate right, paint both new
            // children black and the new top red.
            const node_t left_l = xs[left].l;
            const node_t left_r = xs[left].r;
            const node_t right = xs[x].r;
            const node_t new_left = create(left_l, false);
            const node_t new_right = create(x, left_r, right, false);
            x = create(left, new_left, new_right, true);
        }
        return x;
    }

    // Returns the root of a new version of subtree `x` containing one more `key`.
    node_t insert(node_t x, T key) {
        if (!x) return create(key);
        if (key == xs[x].key) {
            x = clone(x);
            xs[x].cnt++;
            x = fix(x);
        } else if (key < xs[x].key) {
            const node_t right = xs[x].r;
            const node_t new_left = insert(xs[x].l, key);
            x = create(x, new_left, right);
        } else {
            const node_t left = xs[x].l;
            const node_t new_right = insert(xs[x].r, key);
            x = create(x, left, new_right);
        }
        return rebalance(x);
    }

    // Returns the root of a new version of subtree `x` with one occurrence of
    // `key` removed. When `key` is not present the original index is returned
    // unchanged, so nothing is allocated and no count can become negative.
    node_t erase(node_t x, T key) {
        if (!x) return 0;
        if (key == xs[x].key) {
            if (xs[x].cnt <= 0) return x;  // already fully erased earlier
            x = clone(x);
            xs[x].cnt--;
            return fix(x);
        }
        const node_t left = xs[x].l;
        const node_t right = xs[x].r;
        if (key < xs[x].key) {
            const node_t new_left = erase(left, key);
            if (new_left == left) return x;
            return create(x, new_left, right);
        }
        const node_t new_right = erase(right, key);
        if (new_right == right) return x;
        return create(x, left, new_right);
    }

    // Counts the elements of subtree `x` that are strictly less than `key`.
    int countLT(node_t x, T key) const {
        int res = 0;
        while (x) {
            if (xs[x].key < key) {
                // This node and its whole left subtree are below `key`.
                res += xs[xs[x].l].sz + xs[x].cnt;
                x = xs[x].r;
            } else {
                x = xs[x].l;
            }
        }
        return res;
    }
};

template<class T> int MultiSet<T>::curr = 1;
template<class T> vector<MSnode<T>> MultiSet<T>::xs(1);  // slot 0 is the empty sentinel

// Persistent array implemented as a 16-ary trie with path copying.
//
// Copying an Array copies only the height and the root pointer, giving an O(1)
// snapshot. Nodes are shared between snapshots and are never freed.
// Indices passed to get/mut_get must lie in [0, 16^(h/4)); with Array(n) that
// range covers at least [0, n).
template<class T>
class Array {
public:
    // Height-0 array (a single cell at index 0).
    Array() = default;

    // Array able to hold at least `n` cells; unset cells read as T().
    Array(int n) {
        // 64-bit counter so the multiplication cannot overflow for large n.
        for (long long capacity = 1; capacity < n; capacity *= 16) h += 4;
    }

    // Returns a pointer to cell `k` in a new version of the array; the nodes
    // on the path to `k` are copied so older snapshots are left untouched.
    T *mut_get(int k) {
        auto p = mutable_get(k, root, 0, h);
        root = p.first;
        return &p.second->value;
    }

    // Returns the value of cell `k`, or T() if it was never written.
    T get(int k) const { return immutable_get(k, root, 0, h); }

private:
    struct node {
        node *ch[16] = {};
        T value{};  // value-initialised so copying an inner node never reads garbage
        node() {}
        node(T value) : value(value) {}
    };

    int h = 0;  // number of index bits consumed from the root down (4 per level)
    node *root = nullptr;

    // Walks from `x` (covering indices starting at `l`, with `d` bits left)
    // down to the leaf for index `a`; a missing node means the cell is unset.
    T immutable_get(int a, const node *x, int l, int d) const {
        while (x && d != 0) {
            const int id = (a - l) >> (d - 4);
            l += id << (d - 4);
            x = x->ch[id];
            d -= 4;
        }
        return x ? x->value : T();
    }

    // Copies (or creates) the path to index `a`. Returns {new subtree root, leaf}.
    pair<node *, node *> mutable_get(int a, node *x, int l, int d) {
        node *copy = x ? new node(*x) : new node();
        if (d == 0) return { copy, copy };
        const int id = (a - l) >> (d - 4);
        const auto sub = mutable_get(a, copy->ch[id], l + (id << (d - 4)), d - 4);
        copy->ch[id] = sub.first;
        return { copy, sub.second };
    }
};

// Reads q operations (set / remove / query / undo) and keeps a snapshot of the
// (priority multiset, name -> priority array) pair after every operation so
// that an undo can restore any earlier state in O(1).
int main() {
    int q;
    if (scanf("%d", &q) != 1 || q < 0) return 0;

    vector<pair<MultiSet<int>, Array<int>>> his(q + 1);
    MultiSet<int> st;
    Array<int> ar(q);  // at most q distinct names can ever appear
    his[0] = make_pair(st, ar);

    map<string, int> name_to_id;
    // Maps a name to a dense id, assigning the next free id on first sight.
    auto get_id = [&](const string &name) {
        auto it = name_to_id.find(name);
        if (it == name_to_id.end()) {
            it = name_to_id.emplace(name, static_cast<int>(name_to_id.size())).first;
        }
        return it->second;
    };

    for (int ii = 1; ii <= q; ii++) {
        // Field widths keep over-long tokens from overflowing the buffers.
        char type[30];
        char key[30];
        if (scanf("%29s", type) != 1) break;

        if (type[0] == 's') {
            int v;
            if (scanf("%29s %d", key, &v) != 2) break;
            const int k = get_id(key);
            const int pv = ar.get(k);  // 0 means "no priority assigned"
            if (pv != 0) {
                st.erase(pv);
            }
            *ar.mut_get(k) = v;
            st.insert(v);
        } else if (type[0] == 'r') {
            if (scanf("%29s", key) != 1) break;
            const int k = get_id(key);
            const int v = ar.get(k);
            if (v != 0) {
                st.erase(v);
                *ar.mut_get(k) = 0;
            }
        } else if (type[0] == 'q') {
            if (scanf("%29s", key) != 1) break;
            const int k = get_id(key);
            const int v = ar.get(k);
            if (v != 0) {
                printf("%d\n", st.countLT(v));
            } else {
                printf("-1\n");
            }
            fflush(stdout);  // answers must be flushed after every query
        } else {
            int d;
            if (scanf("%d", &d) != 1) break;
            // Undoing the previous d operations restores the snapshot taken
            // after operation ii - d - 1; ignore requests that reach before the start.
            if (d >= 0 && d < ii) {
                const pair<MultiSet<int>, Array<int>> &snapshot = his[ii - d - 1];
                st = snapshot.first;
                ar = snapshot.second;
            }
        }
        his[ii] = make_pair(st, ar);
    }
}