CF #457 E. Jamie and Tree

http://codeforces.com/problemset/problem/916/E

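LC-tree（link-cut tree）は根を変更するクエリと LCA を求めるクエリを処理できる。ET-tree（Euler tour tree）は辺の追加・削除と、連結成分全体に x を一様に加算するクエリを処理できる。現在の根に関する頂点 v の部分木を扱うときは、v とその親（LC-tree で求める）を結ぶ辺を ET-tree 上で一時的に切れば、v を含む連結成分がちょうどその部分木になる。これらを組み合わせることで容易に解ける。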

#include <iostream>
#include <algorithm>
#include <vector>
#include <map>

using namespace std;

// Euler-tour tree (ET-tree) over a forest, implemented with splay trees.
//
// Each vertex u owns one "active" node super[u] that carries the vertex value,
// and each directed edge (u, v) owns one inactive node edges[{u, v}]. Every
// connected component is kept as one splay tree whose in-order sequence is an
// Euler tour of that component, so the splay root aggregates the whole
// component. Supported operations:
//   link(u, v) / cut(u, v) - add / remove the tree edge {u, v},
//   add(u, w)              - add w to every vertex in u's component,
//   find_size(u)           - number of vertices in u's component,
//   access(u)              - sum of vertex values in u's component.
// link/cut expect (u, v) to be an edge of the graph given to the constructor.
struct ETTree {
  struct node {
    node *l = nullptr;
    node *r = nullptr;
    node *p = nullptr;
    bool active;         // true for vertex nodes, false for edge nodes
    int size;            // number of active nodes in this subtree
    long long val = 0;   // vertex value (only meaningful when active)
    long long lazy = 0;  // pending addition not yet pushed to the children
    long long sum = 0;   // sum of val over the active nodes in this subtree

    // Initializers follow declaration order (val before sum) to avoid -Wreorder.
    node(bool active_, long long v = 0) : active(active_), size(active_), val(v), sum(v) {}
  };

  std::map<std::pair<int, int>, node *> edges;
  std::vector<node *> super;
  const int n;

  // Builds the ET-trees for the graph g (adjacency list; every edge is expected
  // in both directions, since link(u, v) uses edges[{u, v}] and edges[{v, u}]).
  // a[i] is the initial value of vertex i.
  ETTree(const std::vector<std::vector<int>> &g, const std::vector<int> &a) : n(g.size()) {
    super.resize(n);
    for (int i = 0; i < n; i++) {
      for (int j : g[i]) {
        edges[{i, j}] = new node(false);
      }
      super[i] = new node(true, a[i]);
    }
    for (int i = 0; i < n; i++) {
      for (int j : g[i]) if (i < j) {
        link(i, j);
      }
    }
  }

  // All nodes are allocated in the constructor and owned by this object.
  ~ETTree() {
    for (auto &entry : edges) delete entry.second;
    for (node *v : super) delete v;
  }

  // Nodes are owned through raw pointers, so a copy would double-delete them.
  ETTree(const ETTree &) = delete;
  ETTree &operator=(const ETTree &) = delete;

  // Adds the tree edge {u, v} (u and v must be in different components).
  // After rotating both tours to start at u and v, they are combined into
  // [u ... (u,v) v ... (v,u)].
  void link(int u, int v) {
    node *uv = edges[{u, v}];
    node *vu = edges[{v, u}];
    node *x = super[u];
    node *y = super[v];
    rotate(x);
    rotate(y);
    join(x, uv);
    join(uv, y);
    join(y, vu);
  }

  // Removes the tree edge {u, v}. Rotating the tour to start at (u,v) gives
  // [(u,v) <v side> (v,u) <u side>]; three splits isolate both edge nodes and
  // leave the two sides as separate tours.
  void cut(int u, int v) {
    node *uv = edges[{u, v}];
    node *vu = edges[{v, u}];
    rotate(uv);
    split_right(uv);
    split_left(vu);
    split_right(vu);
  }

  // Adds w to every vertex of u's component (lazily, at the splay root).
  void add(int u, long long w) {
    put_lazy(splay(super[u]), w);
  }

  // Number of vertices in u's component.
  int find_size(int u) {
    return splay(super[u])->size;
  }

  // Sum of the vertex values in u's component. After splaying, super[u] is the
  // root of the component's splay tree, so its sum covers the whole component.
  long long access(int u) {
    return splay(super[u])->sum;
  }

  // Rotates x above its parent y; y's aggregates are refreshed here, x's by the
  // caller (splay).
  void rot(node *x) {
    node *y = x->p, *z = y->p;
    if (z) {
      if (y == z->l) z->l = x;
      if (y == z->r) z->r = x;
    }
    x->p = z; y->p = x;
    if (x == y->l) {
      y->l = x->r; x->r = y;
      if (y->l) y->l->p = y;
    } else {
      y->r = x->l; x->l = y;
      if (y->r) y->r->p = y;
    }
    pull(y);
  }

  // Pushes pending lazy additions from the splay root down to x.
  void push_above(node *x) {
    if (x->p) push_above(x->p);
    push(x);
  }

  // Splays x to the root of its splay tree (i.e. of its whole component) and
  // returns it.
  node *splay(node *x) {
    push_above(x);
    while (x->p) {
      node *y = x->p, *z = y->p;
      if (z) rot((x == y->l) == (y == z->l) ? y : x);
      rot(x);
    }
    pull(x);
    return x;
  }

  // Recomputes size and sum of x from its own value and its children. Children
  // sums do not include x->lazy, so x must already be pushed.
  void pull(node *x) {
    x->size = x->active;
    if (x->l) x->size += x->l->size;
    if (x->r) x->size += x->r->size;
    // assert(x->lazy == 0);
    x->sum = 0;
    if (x->active) {
      x->sum += x->val;
    }
    if (x->l) x->sum += x->l->sum;
    if (x->r) x->sum += x->r->sum;
  }

  // Adds v to every active node in x's subtree: x's own val/sum are updated now,
  // the children receive it later through lazy.
  void put_lazy(node *x, long long v) {
    x->lazy += v;
    if (x->active) {
      x->val += v;
    }
    x->sum += v * x->size;
  }

  // Hands x's pending addition to its children.
  void push(node *x) {
    if (x->l) put_lazy(x->l, x->lazy);
    if (x->r) put_lazy(x->r, x->lazy);
    x->lazy = 0;
  }

  // [_ _ x _ _] + [_ _ y _ _] -> [_ _ x _ _ _ _ y _ _]
  // Appends y's whole sequence after the last node of x's sequence.
  void join(node *x, node *y) {
    if (!x || !y) return;
    splay(x);
    while (x->r) x = x->r;
    splay(x);
    splay(y);
    // assert(x->lazy == 0);
    x->r = y;
    y->p = x;
    pull(x);
  }

  // [_ _ x _ _] -> [_ _] [x _ _]
  // Returns the root of the detached left part (nullptr if empty).
  node *split_left(node *x) {
    splay(x);
    // assert(x->lazy == 0);
    node *y = x->l;
    if (y) {
      y->p = x->l = nullptr;
      pull(x);
    }
    return y;
  }

  // [_ _ x _ _] -> [_ _ x] [_ _]
  // Returns the root of the detached right part (nullptr if empty).
  node *split_right(node *x) {
    splay(x);
    // assert(x->lazy == 0);
    node *y = x->r;
    if (y) {
      y->p = x->r = nullptr;
      pull(x);
    }
    return y;
  }

  // [0 1 2 x 3 4 5] -> [x 3 4 5 0 1 2]
  // Cyclically rotates the tour so that it starts at x.
  void rotate(node *x) {
    join(x, split_left(x));
  }
};

// Link-cut tree node.
// l / r are splay-tree children. p is either the splay parent or, for the top
// node of a splay tree, a path-parent pointer (is_root tells the two apart).
// rev set means this node's children still owe a reversal (see push).
struct node {
  node *l{nullptr};
  node *r{nullptr};
  node *p{nullptr};
  bool rev{false};
};

// True when x is the top node of its splay tree: either it has no parent at
// all, or its p is a path-parent pointer (x is neither child of p).
bool is_root(node *x) {
  const node *up = x->p;
  if (up == nullptr) return true;
  return up->l != x && up->r != x;
}

// Rotates x above its parent inside a splay tree, keeping in-order order.
// When the parent is the top of its splay tree, the grandparent pointer is a
// path-parent: x inherits it, but the grandparent does not adopt x as a child.
void rot(node *x) {
  node *parent = x->p;
  node *grand = parent->p;
  if (grand != nullptr) {
    if (grand->l == parent) grand->l = x;
    if (grand->r == parent) grand->r = x;
  }
  x->p = grand;
  parent->p = x;
  if (parent->l == x) {
    node *inner = x->r;  // x's inner subtree moves under parent
    parent->l = inner;
    x->r = parent;
    if (inner != nullptr) inner->p = parent;
  } else {
    node *inner = x->l;
    parent->r = inner;
    x->l = parent;
    if (inner != nullptr) inner->p = parent;
  }
}

// Reverses the in-order sequence of x's subtree lazily: x's own children are
// swapped immediately, and the flag records that x's children still have to be
// reversed as well (done by push).
void reverse(node *x) {
  x->rev = !x->rev;
  std::swap(x->r, x->l);
}

// Applies x's pending reversal to its children and clears the flag.
void push(node *x) {
  if (!x->rev) return;
  x->rev = false;
  if (x->l != nullptr) reverse(x->l);
  if (x->r != nullptr) reverse(x->r);
}

// Pushes pending reversals top-down on every node reachable from x through p
// (splay parents and path-parent pointers alike), ending with x itself.
void push_above(node *x) {
  node *up = x->p;
  if (up != nullptr) {
    push_above(up);
  }
  push(x);
}

// Splays x to the top of its own splay tree, stopping at path-parent
// boundaries (is_root). Does not push reversal flags; splayLC pushes first.
void splay(node *x) {
  for (;;) {
    if (is_root(x)) break;
    node *parent = x->p;
    if (!is_root(parent)) {
      node *grand = parent->p;
      const bool same_side = (x == parent->l) == (parent == grand->l);
      rot(same_side ? parent : x);  // zig-zig rotates parent first, zig-zag rotates x
    }
    rot(x);
  }
}

// Link-cut "access": makes the path from the represented tree's root to x one
// splay tree with x on top. The first iteration drops everything deeper than x
// (x->r = nullptr); each following iteration attaches the path built so far as
// the right (deeper) part of the next path-parent.
void splayLC(node *x) {
  push_above(x);
  node *below = nullptr;
  for (node *cur = x; cur != nullptr; cur = cur->p) {
    splay(cur);
    cur->r = below;
    below = cur;
  }
  splay(x);
}

// Makes x the root of its represented tree: expose the root-to-x path, then
// flip it so that x becomes its first (shallowest) node.
void reroot(node *x) {
  node *const target = x;
  splayLC(target);
  reverse(target);  // target is the top splay node, so this flips the whole path
}

// Connects two represented trees by making p the parent of c. c is rerooted
// first so it is the root of its own tree; c and p must be in different trees.
void link(node *c, node *p) {
  reroot(c);
  splayLC(p);
  c->p = p;  // c's splay tree hangs below p ...
  p->r = c;  // ... as the deeper part of p's exposed path
}

// Detaches x from its parent (with respect to the current root), splitting the
// represented tree in two. If x is the root there is no parent edge and the
// call leaves the tree unchanged (previously this dereferenced nullptr).
void cut(node *x) {
  splayLC(x);
  node *above = x->l;  // the vertices above x on the exposed root-to-x path
  if (above == nullptr) {
    return;  // x is the root: nothing to cut
  }
  above->p = nullptr;
  x->l = nullptr;
}

// Returns the lowest common ancestor of x and y in the represented tree (with
// respect to its current root), or nullptr if they are in different trees.
// y is exposed first, then x; afterwards x's splay tree holds the root-to-x
// path, and the topmost path-parent pointer met while walking up from y points
// at the vertex where y's path leaves it.
node *lca(node *x, node *y) {
  splayLC(y);
  splayLC(x);
  bool same = false;  // becomes true once the walk from y reaches x (same tree)
  node *l = y;        // answer candidate; stays y if y lies on the root-to-x path
  while (y != nullptr) {
    // y tops a splay tree and has a path-parent: remember that attach point.
    // Later (higher) attach points overwrite earlier ones.
    if (is_root(y) && y->p) {
      l = y->p;
    }
    // NOTE(review): splayLC(x) clears x->r in its first loop iteration, so this
    // check looks unreachable; confirm before removing.
    if (y == x->r) return x;
    if (x == y) same = true;
    y = y->p;
  }
  if (!same) return nullptr;
  return l;
}

// Returns the parent of x in the represented tree with respect to its current
// root, or nullptr if x is that root.
// After splayLC(x), x's left subtree holds exactly the vertices above x on the
// root-to-x path, so the parent is the rightmost node of that subtree.
node *get_parent(node *x) {
  splayLC(x);
  node *cur = x->l;
  if (cur == nullptr) {
    return nullptr;
  }
  // Push every node before following its child pointer. With eager reversal a
  // pending rev on `cur` means its children have not been swapped yet, so
  // stepping to cur->r first could land on an unswapped node and walk the
  // wrong way.
  push(cur);
  while (cur->r != nullptr) {
    cur = cur->r;
    push(cur);
  }
  // Splay the found node so the descent is paid for (amortized bounds).
  splayLC(cur);
  return cur;
}

// Reads one int from stdin with scanf. Returns 0 if no integer could be read
// (EOF or malformed input) rather than an uninitialized value.
int input() {
  int n = 0;
  if (scanf("%d", &n) != 1) {
    return 0;
  }
  return n;
}

// Solves CF 916E with a link-cut tree (tracks the current root, answers LCA and
// parent queries) and an ET-tree (component add / component sum). The subtree
// of v under the current root is isolated by temporarily cutting the ET-tree
// edge between v and its current parent.
int main() {
  int n, q;
  cin >> n >> q;

  // One link-cut node per vertex plus the reverse map node* -> vertex index;
  // nullptr maps to -1 ("no vertex").
  vector<node *> lc(n);
  map<node *, int> index_of;
  index_of[nullptr] = -1;
  for (int i = 0; i < n; i++) {
    lc[i] = new node();
    index_of[lc[i]] = i;
  }

  vector<int> a(n);
  for (int &value : a) value = input();

  vector<vector<int>> g(n);
  for (int edge = 0; edge < n - 1; edge++) {
    int u = input() - 1;
    int v = input() - 1;
    g[u].push_back(v);
    g[v].push_back(u);
    link(lc[u], lc[v]);
  }
  reroot(lc[0]);

  ETTree et(g, a);
  while (q--) {
    switch (input()) {
      case 1: {
        int v = input() - 1;
        reroot(lc[v]);
        break;
      }
      case 2: {
        int u = input() - 1;
        int v = input() - 1;
        int x = input();
        int l = index_of[lca(lc[u], lc[v])];
        int p = index_of[get_parent(lc[l])];
        const bool has_parent = p != -1;
        if (has_parent) et.cut(l, p);
        et.add(l, x);
        if (has_parent) et.link(l, p);
        break;
      }
      default: {
        int v = input() - 1;
        int p = index_of[get_parent(lc[v])];
        const bool has_parent = p != -1;
        if (has_parent) et.cut(v, p);
        long long ans = et.access(v);
        if (has_parent) et.link(v, p);
        printf("%lld\n", ans);
        break;
      }
    }
  }
}

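恐らくこちらが普通の解法：頂点 0 を根として HLD とオイラーツアー上の BIT を構築し、現在の根に関する部分木を「根 0 での部分木（オイラーツアー上の区間）」か「木全体からそのような区間を除いたもの」として扱う。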

#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <functional>

using namespace std;

// Reads one int from stdin with scanf. Returns 0 if no integer could be read
// (EOF or malformed input) rather than an uninitialized value.
int input() {
  int n = 0;
  if (scanf("%d", &n) != 1) {
    return 0;
  }
  return n;
}

// Heavy-light decomposition of a tree rooted at vertex 0.
//   parent[u] : parent of u in the root-0 tree (-1 for vertex 0)
//   head[u]   : first (topmost) vertex of the chain containing u
//   vid[u]    : DFS index of u; the vertices of a chain get consecutive indices
//   inv[i]    : the vertex whose vid is i
struct HLD {
  using Tree = std::vector<std::vector<int>>;
  std::vector<int> parent, head, vid, inv;

  HLD(const Tree &g) : parent(g.size()), head(g.size()), vid(g.size()), inv(g.size()) {
    int k = 0;
    std::vector<int> size(g.size(), 1);
    // First pass: subtree sizes of the root-0 tree.
    std::function<void(int, int)> dfs = [&](int u, int p) {
      for (int v : g[u]) if (v != p) dfs(v, u), size[u] += size[v];
    };
    // Second pass: numbering. A child holding more than half of u's subtree
    // (size[u] < size[v] * 2; at most one child qualifies) is visited first and
    // continues u's chain, so it gets vid[u] + 1; every other child starts a new
    // chain headed by itself.
    std::function<void(int, int, int)> dfs2 = [&](int u, int p, int h) {
      parent[u] = p; head[u] = h; vid[u] = k++; inv[vid[u]] = u;
      for (int v : g[u]) if (v != p && size[u] < size[v] * 2) dfs2(v, u, h);
      for (int v : g[u]) if (v != p && size[u] >= size[v] * 2) dfs2(v, u, v);
    };
    // NOTE(review): both passes recurse once per tree level, so a path-shaped
    // tree recurses about n calls deep.
    dfs(0, -1);
    dfs2(0, -1, 0);
  }

  // Each round orders the arguments so that vid[a] >= vid[b] >= vid[c]; if the
  // two highest-numbered vertices share a chain, b is returned, otherwise a
  // jumps to the parent of its chain head. main() calls this as
  // lca(u, v, root) for the LCA of u and v when the tree is rooted at `root`
  // (presumably the median of the three vertices - TODO confirm).
  int lca(int a, int b, int c) {
    while (true) {
      if (vid[a] < vid[b]) swap(a, b);
      if (vid[a] < vid[c]) swap(a, c);
      if (vid[b] < vid[c]) swap(b, c);
      if (head[a] == head[b]) return b;
      a = parent[head[a]];
    }
  }

  // Returns the neighbour of u on the way to r, i.e. u's parent when the tree
  // is rooted at r, or -1 when u == r. r climbs chain by chain; if it reaches
  // u's chain below u, or a chain hanging directly off u, r lies in u's root-0
  // subtree and the answer is the child of u toward r. Otherwise the answer is
  // u's root-0 parent.
  int find_parent(int u, int r) {
    if (u == r) return -1;
    while (vid[u] < vid[r]) {
      // Same chain with u above r: the next vertex on the chain leads to r.
      if (head[u] == head[r]) return inv[vid[u] + 1];
      // r's chain hangs directly below u: its head is the child toward r.
      if (parent[head[r]] == u) return head[r];
      r = parent[head[r]];
    }
    return parent[u];
  }
};

// Euler-tour (DFS preorder) numbering of a tree rooted at vertex 0.
// l[u] is u's preorder index and r[u] is one past the last preorder index in
// u's subtree, so u's subtree occupies the half-open range [l[u], r[u]).
// Children are visited in adjacency-list order, using an explicit stack.
struct EulerTour {
  std::vector<int> l;
  std::vector<int> r;

  EulerTour(const std::vector<std::vector<int>> &g) {
    const int n = g.size();
    l.assign(n, 0);
    r.assign(n, 0);

    struct Frame {
      int vertex;
      int parent;
      std::size_t next;  // index of the next neighbour of `vertex` to examine
    };
    int counter = 0;
    std::vector<Frame> stack;
    stack.push_back({0, -1, 0});
    l[0] = counter++;
    while (!stack.empty()) {
      Frame &top = stack.back();
      const int u = top.vertex;
      if (top.next >= g[u].size()) {
        r[u] = counter;  // every descendant of u has been numbered
        stack.pop_back();
        continue;
      }
      const int v = g[u][top.next++];
      if (v == top.parent) continue;
      l[v] = counter++;
      stack.push_back({v, u, 0});  // `top` may dangle after this; not used again
    }
  }
};

// Fenwick tree with range add and range sum over positions [0, n).
// Two Fenwick arrays are kept so that the prefix sum over [0, k) equals
// (prefix of s0) + k * (prefix of s1).
template<class T>
struct BIT {
  std::vector<T> s0;
  std::vector<T> s1;

  BIT(int n) : s0(n + 2), s1(n + 2) {}

  // Adds v to every position >= k.
  void add0(int k, T v) {
    const T offset = v * k;
    for (int idx = k + 1; static_cast<std::size_t>(idx) < s0.size(); idx += idx & -idx) {
      s1[idx] += v;
      s0[idx] -= offset;
    }
  }

  // Adds v to every position in [l, r).
  void add(int l, int r, T v) {
    add0(l, v);
    add0(r, -v);
  }

  // Sum of positions [0, k).
  T sum0(int k) {
    T base = 0;
    T slope = 0;
    for (int idx = k + 1; idx > 0; idx -= idx & -idx) {
      base += s0[idx];
      slope += s1[idx];
    }
    return base + slope * k;
  }

  // Sum of positions [l, r).
  T sum(int l, int r) {
    return sum0(r) - sum0(l);
  }

  // Sum of all n positions.
  T sumall() {
    const int n = static_cast<int>(s0.size()) - 2;
    return sum0(n);
  }
};

// Solves CF 916E with a fixed root-0 decomposition. HLD answers LCA / parent
// under the current root; a BIT over Euler-tour positions handles subtree add
// and sum. A subtree under the current root is either a root-0 subtree (one
// Euler-tour range) or the whole tree minus one such range.
int main() {
  int n, q;
  cin >> n >> q;
  vector<vector<int>> g(n);
  vector<int> a(n);
  for (int &value : a) value = input();
  for (int edge = 0; edge < n - 1; edge++) {
    int u = input() - 1;
    int v = input() - 1;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  EulerTour et(g);
  HLD hld(g);

  int root = 0;
  // Value added uniformly to every vertex; whole-tree additions go here instead
  // of into the BIT.
  long long all = 0;

  BIT<long long> bit(n);
  for (int i = 0; i < n; i++) {
    bit.add(et.l[i], et.l[i] + 1, a[i]);
  }

  // Current sum over the root-0 subtree of w.
  auto subtree_sum = [&](int w) -> long long {
    return all * (et.r[w] - et.l[w]) + bit.sum(et.l[w], et.r[w]);
  };
  // Current sum over the whole tree.
  auto total_sum = [&]() -> long long {
    return all * n + bit.sumall();
  };

  while (q--) {
    switch (input()) {
      case 1:
        root = input() - 1;
        break;
      case 2: {
        int u = input() - 1;
        int v = input() - 1;
        int x = input();
        int l = hld.lca(u, v, root);
        int p = hld.find_parent(l, root);
        if (p == -1) {
          all += x;                       // l is the current root: whole tree
        } else if (p == hld.parent[l]) {
          bit.add(et.l[l], et.r[l], x);   // same subtree as in the root-0 tree
        } else {
          all += x;                       // whole tree minus the root-0
          bit.add(et.l[p], et.r[p], -x);  // subtree of p, a root-0 child of l
        }
        break;
      }
      default: {
        int v = input() - 1;
        int p = hld.find_parent(v, root);
        long long ans;
        if (p == -1) {
          ans = total_sum();
        } else if (p == hld.parent[v]) {
          ans = subtree_sum(v);
        } else {
          ans = total_sum() - subtree_sum(p);
        }
        printf("%lld\n", ans);
        break;
      }
    }
  }
}

感想

ET-tree が遅い。