HourRank 20: Birjik and Nicole's Tree Game

解法

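HL 分解は LCA を求める部分にしか使っていない。各クエリでは、指定された頂点と根を帰りがけ順（post-order）に並べ、先頭の頂点とその次の頂点の LCA を追加しながら補助木（virtual tree）を作る。補助木を葉側から処理して「部分木に含まれる指定頂点の個数」を累積し、補助木の各辺に対応する元の木のパス上の頂点数を、その個数ごとに数え上げる。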

#include <algorithm>
#include <cstdio>
#include <functional>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <utility>
#include <vector>

// Heavy-light decomposition of a tree rooted at vertex 0.
//
// vid[v] numbers the vertices so that every heavy path occupies a contiguous,
// top-down range of ids; head[v] is the topmost vertex of v's heavy path.
// Only a reference to the adjacency list is stored, so the tree must outlive
// this object.
struct HLDecomp {
  using Tree = std::vector<std::vector<int>>;
  const Tree &tree;

  std::vector<int> parent;  // parent[v]; -1 for the root
  std::vector<int> vid;     // position of v in the heavy-path ordering
  std::vector<int> head;    // topmost vertex of v's heavy path

  // Builds the decomposition from an undirected adjacency list.
  HLDecomp(const Tree &tree) : tree(tree) {
    const int n = static_cast<int>(tree.size());
    const int root = 0;

    parent.assign(n, -1);
    vid.assign(n, -1);
    head.assign(n, -1);
    if (n == 0) {
      return;  // nothing to decompose; avoids indexing tree[root]
    }

    // Iterative DFS (no recursion depth limit): computes parent, subtree
    // sizes and the heavy child (largest subtree; the first one wins ties).
    std::vector<int> heavy(n, -1);
    std::vector<int> size(n, 1);
    std::stack<std::pair<int, int>> stack;  // (vertex, next adjacency index)
    stack.emplace(root, 0);
    while (!stack.empty()) {
      const int u = stack.top().first;
      const int i = stack.top().second;
      if (i < static_cast<int>(tree[u].size())) {
        stack.top().second++;
        const int v = tree[u][i];
        if (v != parent[u]) {
          parent[v] = u;
          stack.emplace(v, 0);
        }
      } else {
        // Every child of u is finished, so their sizes are final.
        stack.pop();
        int max = 0;
        for (int v : tree[u]) {
          if (v != parent[u]) {
            size[u] += size[v];
            if (max < size[v]) {
              max = size[v];
              heavy[u] = v;
            }
          }
        }
      }
    }

    // BFS over path heads: each head's heavy path receives consecutive ids,
    // and light children are queued as heads of new paths. Heads inside a
    // light subtree are therefore always numbered after the path the subtree
    // hangs from, which is what lets lca()/foreach() climb from the vertex
    // with the larger id without ever jumping past the LCA.
    std::queue<int> queue;
    queue.push(root);
    int now = 0;
    while (!queue.empty()) {
      const int h = queue.front();
      queue.pop();
      for (int i = h; i != -1; i = heavy[i]) {
        vid[i] = now++;
        head[i] = h;
        for (int j : tree[i]) {
          if (j != parent[i] && j != heavy[i]) {
            queue.push(j);
          }
        }
      }
    }
  }

  // Splits the u-v path into vid ranges and calls func(l, r) for each
  // inclusive range [l, r] with l <= r. Every vertex on the path (both
  // endpoints and their LCA included) is covered by exactly one range.
  template<typename T>
  void foreach(int u, int v, T func) const {
    while (true) {
      if (vid[u] > vid[v]) {
        std::swap(u, v);
      }
      if (head[u] != head[v]) {
        // The segment [head[v], v] of v's heavy path lies on the u-v path.
        func(vid[head[v]], vid[v]);
        v = parent[head[v]];
      } else {
        func(vid[u], vid[v]);
        break;
      }
    }
  }

  // Lowest common ancestor of u and v in the tree rooted at 0.
  int lca(int u, int v) const {
    while (true) {
      if (vid[u] > vid[v]) {
        std::swap(u, v);
      }
      if (head[u] == head[v]) {
        return u;
      }
      v = parent[head[v]];
    }
  }
};

// Reads a tree (vertices 1..n, rooted at vertex 1) and answers queries.
// Each query lists k vertices; the output is, for j = 0..k, the number of
// tree vertices whose subtree contains exactly j of the listed vertices
// (counted with multiplicity).
int main() {
  int n;
  if (scanf("%d", &n) != 1 || n <= 0) {
    return 0;  // no vertices: nothing to root, nothing to answer
  }

  // Undirected adjacency list; input vertices are 1-indexed.
  std::vector<std::vector<int>> tree(n);
  for (int i = 0; i < n - 1; i++) {
    int u, v;
    scanf("%d %d", &u, &v);
    u--;
    v--;
    tree[u].push_back(v);
    tree[v].push_back(u);
  }

  HLDecomp hld(tree);  // used only for LCA queries
  int q;
  scanf("%d", &q);

  // post[u]: post-order index of u; ord[i]: vertex whose post-order index is i;
  // depth[u]: distance from the root 0. An explicit stack reproduces the
  // recursive DFS order exactly (children in adjacency order) without risking
  // a call-stack overflow on deep, path-like trees.
  std::vector<int> post(n);
  std::vector<int> ord;
  ord.reserve(n);
  std::vector<int> depth(n);
  {
    std::vector<int> dfs_parent(n, -1);
    std::vector<std::pair<int, int>> stack;  // (vertex, next adjacency index)
    stack.emplace_back(0, 0);
    while (!stack.empty()) {
      const int u = stack.back().first;
      const int i = stack.back().second;
      if (i < static_cast<int>(tree[u].size())) {
        stack.back().second++;
        const int v = tree[u][i];
        if (v != dfs_parent[u]) {
          dfs_parent[v] = u;
          depth[v] = depth[u] + 1;
          stack.emplace_back(v, 0);
        }
      } else {
        stack.pop_back();
        post[u] = static_cast<int>(ord.size());
        ord.push_back(u);
      }
    }
  }

  // imos[v]: number of query vertices inside v's subtree, accumulated over
  // the auxiliary tree. Only touched entries are reset after each query.
  std::vector<int> imos(n);
  while (q--) {
    int k;
    scanf("%d", &k);

    // Query vertices plus the root, keyed (and therefore ordered) by post-order.
    std::set<int> pending;
    pending.insert(post[0]);
    for (int i = 0; i < k; i++) {
      int t;
      scanf("%d", &t);
      t--;
      pending.insert(post[t]);
      imos[t]++;
    }

    // Build the auxiliary (virtual) tree: pop vertices in increasing
    // post-order; the parent of the popped vertex is its LCA with the next
    // remaining vertex, and that LCA joins the set (a no-op if present).
    std::vector<int> a;  // auxiliary-tree vertices, children before parents
    std::map<int, int> parent;
    while (!pending.empty()) {
      const int x = ord[*pending.begin()];
      pending.erase(pending.begin());
      a.push_back(x);
      if (!pending.empty()) {
        const int y = ord[*pending.begin()];
        const int z = hld.lca(x, y);
        pending.insert(post[z]);
        parent[x] = z;
      }
    }

    // ans[j]: number of vertices whose subtree holds exactly j query vertices.
    std::vector<int> ans(k + 1);
    for (const int v : a) {
      if (v != 0) {
        const int p = parent.at(v);
        imos[p] += imos[v];
        // v and its ancestors strictly below p share v's subtree count.
        ans[imos[v]] += depth[v] - depth[p];
      } else {
        ans[imos[v]]++;
      }
    }
    // Every vertex outside the auxiliary-tree paths contains no query vertex.
    ans[0] = n;
    for (int i = 1; i <= k; i++) {
      ans[0] -= ans[i];
    }
    for (int i = 0; i <= k; i++) {
      printf("%d ", ans[i]);
    }
    printf("\n");

    for (const int v : a) {
      imos[v] = 0;
    }
  }
}