August Long Challenge 2017: Walks on the binary tree

計算量
\(O(Q \log^2 N)\)
キーワード
persistent segment tree | zobrist hash | binary counter | binary search on segment tree | lcp

考察

  • これはそもそもbinary stringの問題である。
  • 文字列数が2であればlen(s)+len(t)-lcp(s,t)で計算できる。複数文字列があるならソートして隣接要素間のLCPをすべて引けば良い。
  • sortはLCPがあればいいので、LCPの方法を考える。
  • ハッシュを用いてLCPを求めることにする。zobrist hash化したものをsegtreeに載せておけば、LCPは二分探索でいける。
  • 二進カウンタと同じ考え方で、更新クエリで書き換わる桁数はO(1) amortizedである。文字列をsegtreeで持つことにしていたので、永続により空間O(Q log N)で保持できるということになる。
  • segtree上の[0,k)型クエリの二分探索は工夫することでO(log^2 n)からO(log N)に落とせる。fenwick tree上の二分探索(lower_bound)として有名?
  • 先読みソート+linked listを用いて動的更新に対応させた。
#include <iostream>
#include <algorithm>
#include <vector>
#include <cassert>
#include <string>
 
using namespace std;
 
struct node {
  node *l = nullptr;
  node *r = nullptr;
  uint64_t sum = 0;
};
 
constexpr int N = 1 << 17;
 
uint64_t rnd[N];
 
uint64_t sum(node *x) {
  return x ? x->sum : 0ULL;
}
 
node *build(int l, int r) {
  node *ret = new node();
  if (r - l == 1) return ret;
  ret->l = build(l, l + r >> 1);
  ret->r = build(l + r >> 1, r);
  return ret;
}
 
int getval(int k, node *x, int l = 0, int r = N) {
  if (r - l == 1) {
    return x->sum != 0;
  }
  int m = l + r >> 1;
  if (k < m) {
    return getval(k, x->l, l, m);
  } else {
    return getval(k, x->r, m, r);
  }
}
 
node *setval(int k, int v, node *x, int l = 0, int r = N) {
  x = x ? new node(*x) : new node();
  if (r - l == 1) {
    x->sum = v ? rnd[l] : 0;
    return x;
  }
  int m = l + r >> 1;
  if (k < m) {
    x->l = setval(k, v, x->l, l, m);
  } else {
    x->r = setval(k, v, x->r, m, r);
  }
  x->sum = sum(x->l) ^ sum(x->r);
  return x;
}
 
int lcp(node *x, node *y, int k = N) {
  if (k == 1) return 0;
  if (sum(x->l) == sum(y->l)) {
    return lcp(x->r, y->r, k / 2) + k / 2;
  } else {
    return lcp(x->l, y->l, k / 2);
  }
}
 
bool compare(node *x, node *y) {
  int l = lcp(x, y);
  if (l == N - 1) return false;
  return getval(l, x) < getval(l, y);
}
 
void solve() {
  int n, q;
  cin >> n >> q;
 
  node *curr = build(0, N);
 
  vector<node *> strs(q);
 
  vector<bool> query(q);
  for (int i = 0; i < q; i++) {
    char type;
    scanf(" %c", &type);
    if (type == '!') {
      int k;
      scanf("%d", &k);
      k = n - 1 - k;
      for (; k >= 0; k--) {
        bool b = getval(k, curr);
        curr = setval(k, !b, curr);
        if (!b) break;
      }
    } else {
      query[i] = true;
    }
    strs[i] = curr;
  }
 
  vector<int> perm(q);
  for (int i = 0; i < q; i++) {
    perm[i] = i;
  }
 
  sort(perm.begin(), perm.end(), [&](int i, int j) {
    return compare(strs[i], strs[j]);
  });
 
  vector<int> inv(q);
  for (int i = 0; i < q; i++) {
    inv[perm[i]] = i;
  }
 
  long long ans = 1LL * q * n + 1;
  for (int ii = 0; ii + 1 < q; ii++) {
    int i = perm[ii];
    int j = perm[ii + 1];
    ans -= min(n, lcp(strs[i], strs[j]));
  }
 
  vector<int> left(q, -1);
  vector<int> right(q, -1);
 
  for (int i = 0; i + 1 < q; i++) {
    left[i + 1] = i;
    right[i] = i + 1;
  }
 
  auto del = [&](int i) {
    int y = inv[i];
    int x = left[y];
    int z = right[y];
    if (x != -1) ans += min(n, lcp(strs[perm[x]], strs[perm[y]]));
    if (z != -1) ans += min(n, lcp(strs[perm[y]], strs[perm[z]]));
    if (x != -1 && z != -1) ans -= min(n, lcp(strs[perm[x]], strs[perm[z]]));
    if (x != -1) right[x] = z;
    if (z != -1) left[z] = x;
    ans -= n;
  };
  for (int i = 0; i < q; i++) {
    if (query[i]) {
      del(i);
    }
  }
 
  vector<long long> anss;
  for (int i = q - 1; i >= 0; i--) {
    if (query[i]) {
      anss.push_back(ans);
    } else {
      del(i);
    }
  }
 
  for (int i = (int)anss.size() - 1; i >= 0; i--) {
    printf("%lld\n", anss[i]);
  }
}
 
uint64_t xorshift() {
  static uint64_t x = 1234567;
  x ^= x << 13;
  x ^= x >> 7;
  x ^= x << 17;
  return x;
}
 
int main() {
  for (int i = 0; i < N; i++) {
    rnd[i] = xorshift();
  }
 
  int T;
  cin >> T;
  while (T--) {
    solve();
  }
}