CF #457 D. Jamie and To-do List

http://codeforces.com/problemset/problem/916/D

誤読さえしなければ解法は自明。永続multiset と永続配列があれば良い。永続multisetは使いまわせそうな形で再実装した。内部的には LLRBtree を用いている。
ポインタ型が64bitなのが厄介で、nodeに子へのポインタを持たせるとMLEしたため、MultiSetではノードを配列のインデックス(int)で参照している。(ただしCFは32bit環境らしいので、なぜポインタ版がMLEしたのかは分かっていない…)

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <vector>

using namespace std;

// Node of the persistent multiset below. Nodes are stored in a shared pool
// (MultiSet<T>::xs) and refer to their children by pool index; index 0 is a
// sentinel meaning "no node" (sz == 0, cnt == 0, black).
// Every member has a default initializer so a node is never read
// uninitialized, even when created outside the pool.
template<class T> struct MSnode {
  int l = 0;       // pool index of the left child (0 = none)
  int r = 0;       // pool index of the right child (0 = none)
  int sz = 0;      // number of elements (with multiplicity) in this subtree
  int cnt = 0;     // multiplicity of `key`; the node stays linked even at 0
  T key{};
  bool c = false;  // colour: true = red, false = black
};

// Persistent multiset: a left-leaning red-black tree with path copying.
// Updates never modify nodes that an older version can reach; they allocate
// new ones instead. Copying a MultiSet copies only its root index, so the
// copy keeps seeing the contents it had when it was taken.
// erase() only lowers a node's count -- nodes are never unlinked.
// All MultiSet<T> objects share one static node pool that only ever grows.
template<class T> class MultiSet {
  using node_t = int;
  static int curr;                   // next unused index in xs
  static std::vector<MSnode<T>> xs;  // node pool; xs[0] is the sentinel

  node_t root = 0;

public:
  // Adds one copy of key.
  void insert(const T &key) {
    root = insert(root, key);
    xs[root].c = false;  // keep the root black
  }

  // Removes one copy of key. Does nothing if key is not present.
  void erase(const T &key) {
    root = erase(root, key);
    xs[root].c = false;  // keep the root black
  }

  // Number of elements (counting multiplicity) strictly less than key.
  int countLT(const T &key) const {
    return countLT(root, key);
  }

private:
  // Recomputes x's subtree size from its children and its own count.
  static node_t fix(node_t x) {
    xs[x].sz = xs[xs[x].l].sz + xs[xs[x].r].sz + xs[x].cnt;
    return x;
  }

  // Allocates a red node with count 1 and no children, doubling the pool
  // when it is full. Growing the pool may reallocate it, which is why the
  // callers below read child indices into locals before allocating.
  static node_t create() {
    if (xs.size() <= static_cast<std::size_t>(curr)) {
      xs.resize(xs.size() * 2);
    }
    xs[curr].sz = 1;
    xs[curr].cnt = 1;
    xs[curr].c = true;
    return curr++;
  }

  // New red leaf holding one copy of key.
  static node_t create(const T &key) {
    const node_t x = create();
    xs[x].key = key;
    return x;
  }

  // Copy of node x at a fresh index.
  static node_t clone(node_t x) {
    const node_t y = create();
    xs[y] = xs[x];
    return y;
  }

  // Copy of x with children (l, r) and colour c; size recomputed.
  static node_t create(node_t x, node_t l, node_t r, bool c) {
    x = clone(x);
    xs[x].l = l;
    xs[x].r = r;
    xs[x].c = c;
    return fix(x);
  }

  // Copy of x with children (l, r); size recomputed, colour kept.
  static node_t create(node_t x, node_t l, node_t r) {
    x = clone(x);
    xs[x].l = l;
    xs[x].r = r;
    return fix(x);
  }

  // Copy of x with colour c (children and size unchanged).
  static node_t create(node_t x, bool c) {
    x = clone(x);
    xs[x].c = c;
    return x;
  }

  // Path-copying insert into the subtree rooted at x; returns the new root.
  static node_t insert(node_t x, const T &key) {
    if (!x) return create(key);
    if (key == xs[x].key) {
      x = clone(x);
      xs[x].cnt++;
      x = fix(x);
    } else if (key < xs[x].key) {
      // Recurse first: it may grow the pool, so no reference into xs is
      // held across the call.
      const node_t l = insert(xs[x].l, key);
      x = create(x, l, xs[x].r);
    } else {
      const node_t r = insert(xs[x].r, key);
      x = create(x, xs[x].l, r);
    }
    // A red right link is rotated to the left; the node moved up takes x's
    // colour and x becomes its red left child.
    if (xs[xs[x].r].c) {
      const node_t r = xs[x].r;
      const node_t down = create(x, xs[x].l, xs[r].l, true);
      x = create(r, down, xs[r].r, xs[x].c);
    }
    // Two red left links in a row: rotate right and split, making both
    // children black and the new subtree root red.
    if (xs[xs[x].l].c && xs[xs[xs[x].l].l].c) {
      const node_t l = xs[x].l;
      const node_t left = create(xs[l].l, false);
      const node_t right = create(x, xs[l].r, xs[x].r, false);
      x = create(l, left, right, true);
    }
    return x;
  }

  // Path-copying removal of one copy of key from the subtree rooted at x.
  // Returns x itself (nothing copied) when key is absent or its count is
  // already 0, so the count can never become negative.
  static node_t erase(node_t x, const T &key) {
    if (!x) return 0;
    if (key == xs[x].key) {
      if (xs[x].cnt == 0) return x;
      x = clone(x);
      xs[x].cnt--;
      return fix(x);
    }
    if (key < xs[x].key) {
      const node_t l = erase(xs[x].l, key);
      // An unchanged child index means nothing was removed below.
      return l == xs[x].l ? x : create(x, l, xs[x].r);
    }
    const node_t r = erase(xs[x].r, key);
    return r == xs[x].r ? x : create(x, xs[x].l, r);
  }

  // Number of elements strictly less than key in the subtree rooted at x.
  static int countLT(node_t x, const T &key) {
    int res = 0;
    while (x) {
      if (xs[x].key < key) {
        res += xs[xs[x].l].sz + xs[x].cnt;
        x = xs[x].r;
      } else {
        x = xs[x].l;
      }
    }
    return res;
  }
};
template<class T> int MultiSet<T>::curr = 1;
template<class T> std::vector<MSnode<T>> MultiSet<T>::xs(1);

// Persistent array: a 16-ary trie with path copying. Each trie level
// consumes 4 bits of the index. Copying an Array copies only the root
// pointer, so copies share nodes and each keeps seeing its own version.
// Slots never written read as T(). Nodes are never freed (they may be shared
// by many versions), so memory grows with every mut_get.
template<class T> class Array {
public:
  // Depth-0 array: it holds a single element, which every index refers to.
  Array() {}

  // Array addressing indices in [0, n).
  Array(int n) {
    // h = 4 * depth, where depth is the smallest value with 16^depth >= n.
    // The capacity is tracked in 64 bits: an int would overflow (and, in
    // practice, loop forever) once n exceeds 2^28.
    h = 0;
    for (long long cap = 1; cap < n; cap *= 16) h += 4;
  }

  // Makes a new version in which the path to slot k is copied, installs it
  // as this Array's current version, and returns a pointer to slot k in it.
  // Copies taken earlier are unaffected. Write through the pointer before
  // copying this Array again, since a later copy shares the same leaf.
  // Precondition: 0 <= k < n.
  T *mut_get(int k) {
    auto p = mutable_get(k, root, 0, h);
    root = p.first;
    return &p.second->value;
  }

  // Value at index k in the current version (T() if never written).
  // Precondition: 0 <= k < n.
  T get(int k) const {
    return immutable_get(k, root, 0, h);
  }

private:
  struct node {
    node *ch[16] = {};
    T value{};  // value-initialized so a freshly created leaf reads as T()

    node() {}
    node(T value) : value(value) {}
  };

  int h = 0;             // 4 * trie depth (0 for a default-constructed Array)
  node *root = nullptr;  // nullptr = every slot still holds T()

  // Walks from x (covering indices starting at l, d bits deep) to index a.
  T immutable_get(int a, const node *x, int l, int d) const {
    if (!x) return T();
    if (d == 0) return x->value;
    const int id = (a - l) >> (d - 4);
    return immutable_get(a, x->ch[id], l + (id << (d - 4)), d - 4);
  }

  // Copies (or creates) every node on the path to index a.
  // Returns {new subtree root, new leaf for a}.
  std::pair<node *, node *> mutable_get(int a, node *x, int l, int d) {
    x = x ? new node(*x) : new node();
    if (d == 0) return { x, x };
    const int id = (a - l) >> (d - 4);
    auto p = mutable_get(a, x->ch[id], l + (id << (d - 4)), d - 4);
    x->ch[id] = p.first;
    return { x, p.second };
  }
};

// Reads q operations (set / remove / query / undo) and answers each query
// with the number of listed assignments whose priority is strictly smaller
// than the queried one's, or -1 if that assignment is not on the list.
// his[i] is the state after day i (his[0] is the empty state). A state is
//   - a persistent multiset of the priorities currently on the list, and
//   - a persistent array mapping an assignment id to its priority, where 0
//     means "not on the list" (assumes input priorities are never 0 --
//     verify against the problem constraints).
int main() {
  int q;
  if (scanf("%d", &q) != 1) return 0;

  vector<pair<MultiSet<int>, Array<int>>> his(q + 1);
  MultiSet<int> st;
  // Each operation names at most one assignment, so fewer than q ids are
  // ever handed out.
  Array<int> ar(q);
  his[0] = make_pair(st, ar);

  // Dense ids for assignment names, in order of first appearance.
  // (Deliberately not named `stoi`, which would shadow std::stoi.)
  map<string, int> ids;
  auto get_id = [&ids](const string &name) {
    const int next = static_cast<int>(ids.size());
    return ids.emplace(name, next).first->second;
  };

  for (int day = 1; day <= q; day++) {
    // Field widths keep over-long tokens from overflowing the buffers.
    char type[30];
    char key[30];
    if (scanf("%29s", type) != 1) break;

    if (type[0] == 's') {  // set <name> <priority>
      int v;
      scanf("%29s %d", key, &v);
      const int k = get_id(key);
      const int pv = ar.get(k);
      if (pv != 0) st.erase(pv);  // drop the previous priority first
      *ar.mut_get(k) = v;
      st.insert(v);
    } else if (type[0] == 'r') {  // remove <name>
      scanf("%29s", key);
      const int k = get_id(key);
      const int v = ar.get(k);
      if (v != 0) {
        st.erase(v);
        *ar.mut_get(k) = 0;
      }
    } else if (type[0] == 'q') {  // query <name>
      scanf("%29s", key);
      const int k = get_id(key);
      const int v = ar.get(k);
      if (v != 0) {
        printf("%d\n", st.countLT(v));
      } else {
        printf("-1\n");
      }
      // Each answer is flushed immediately (presumably the judge supplies
      // operations online -- see the problem statement).
      fflush(stdout);
    } else {  // undo <d>: go back to the state before the previous d days
      int d;
      scanf("%d", &d);
      tie(st, ar) = his[day - d - 1];
    }
    his[day] = make_pair(st, ar);
  }
}