考察
- これはそもそもbinary stringの問題である。
- 文字列数が2であれば
len(s)+len(t)-lcp(s,t)
で計算できる。複数文字列があるならソートして隣接要素間のLCPをすべて引けば良い。
- sortはLCPがあればいいので、LCPの方法を考える。
- ハッシュを用いてLCPを求めることにする。zobrist hash化したものをsegtreeに載せておけば、LCPは二分探索でいける。
- 二進カウンタと同じ考え方で、更新クエリで書き換わる桁数は
O(1) amortized
である。文字列をsegtreeで持つことにしていたので、永続により空間O(Q log N)
で保持できるということになる。
- segtree上の[0,k)型クエリの二分探索は工夫することで
O(log^2 n)
からO(log N)
に落とせる。fenwick tree上の二分探索(lower_bound)として有名?
- 先読みソート+linked listを用いて動的更新に対応させた。
#include <iostream>
#include <algorithm>
#include <vector>
#include <cassert>
#include <string>
using namespace std;
struct node {
node *l = nullptr;
node *r = nullptr;
uint64_t sum = 0;
};
constexpr int N = 1 << 17;
uint64_t rnd[N];
uint64_t sum(node *x) {
return x ? x->sum : 0ULL;
}
node *build(int l, int r) {
node *ret = new node();
if (r - l == 1) return ret;
ret->l = build(l, l + r >> 1);
ret->r = build(l + r >> 1, r);
return ret;
}
int getval(int k, node *x, int l = 0, int r = N) {
if (r - l == 1) {
return x->sum != 0;
}
int m = l + r >> 1;
if (k < m) {
return getval(k, x->l, l, m);
} else {
return getval(k, x->r, m, r);
}
}
node *setval(int k, int v, node *x, int l = 0, int r = N) {
x = x ? new node(*x) : new node();
if (r - l == 1) {
x->sum = v ? rnd[l] : 0;
return x;
}
int m = l + r >> 1;
if (k < m) {
x->l = setval(k, v, x->l, l, m);
} else {
x->r = setval(k, v, x->r, m, r);
}
x->sum = sum(x->l) ^ sum(x->r);
return x;
}
int lcp(node *x, node *y, int k = N) {
if (k == 1) return 0;
if (sum(x->l) == sum(y->l)) {
return lcp(x->r, y->r, k / 2) + k / 2;
} else {
return lcp(x->l, y->l, k / 2);
}
}
bool compare(node *x, node *y) {
int l = lcp(x, y);
if (l == N - 1) return false;
return getval(l, x) < getval(l, y);
}
void solve() {
int n, q;
cin >> n >> q;
node *curr = build(0, N);
vector<node *> strs(q);
vector<bool> query(q);
for (int i = 0; i < q; i++) {
char type;
scanf(" %c", &type);
if (type == '!') {
int k;
scanf("%d", &k);
k = n - 1 - k;
for (; k >= 0; k--) {
bool b = getval(k, curr);
curr = setval(k, !b, curr);
if (!b) break;
}
} else {
query[i] = true;
}
strs[i] = curr;
}
vector<int> perm(q);
for (int i = 0; i < q; i++) {
perm[i] = i;
}
sort(perm.begin(), perm.end(), [&](int i, int j) {
return compare(strs[i], strs[j]);
});
vector<int> inv(q);
for (int i = 0; i < q; i++) {
inv[perm[i]] = i;
}
long long ans = 1LL * q * n + 1;
for (int ii = 0; ii + 1 < q; ii++) {
int i = perm[ii];
int j = perm[ii + 1];
ans -= min(n, lcp(strs[i], strs[j]));
}
vector<int> left(q, -1);
vector<int> right(q, -1);
for (int i = 0; i + 1 < q; i++) {
left[i + 1] = i;
right[i] = i + 1;
}
auto del = [&](int i) {
int y = inv[i];
int x = left[y];
int z = right[y];
if (x != -1) ans += min(n, lcp(strs[perm[x]], strs[perm[y]]));
if (z != -1) ans += min(n, lcp(strs[perm[y]], strs[perm[z]]));
if (x != -1 && z != -1) ans -= min(n, lcp(strs[perm[x]], strs[perm[z]]));
if (x != -1) right[x] = z;
if (z != -1) left[z] = x;
ans -= n;
};
for (int i = 0; i < q; i++) {
if (query[i]) {
del(i);
}
}
vector<long long> anss;
for (int i = q - 1; i >= 0; i--) {
if (query[i]) {
anss.push_back(ans);
} else {
del(i);
}
}
for (int i = (int)anss.size() - 1; i >= 0; i--) {
printf("%lld\n", anss[i]);
}
}
uint64_t xorshift() {
static uint64_t x = 1234567;
x ^= x << 13;
x ^= x >> 7;
x ^= x << 17;
return x;
}
int main() {
for (int i = 0; i < N; i++) {
rnd[i] = xorshift();
}
int T;
cin >> T;
while (T--) {
solve();
}
}