# CF #457 E. Jamie and Tree

http://codeforces.com/problemset/problem/916/E

LC-tree は根を変更するクエリと lca を求めるクエリを捌ける。ET-tree は辺の追加・削除と連結成分全体に x を一様加算するというクエリが捌ける。これらを組み合わせることで容易に解くことができる。

#include <iostream>
#include <algorithm>
#include <vector>
#include <map>

using namespace std;

// Euler-tour tree over a forest.
//
// Every component is kept as one splay tree whose in-order sequence is an
// Euler tour of that component. The sequence contains one "super" node per
// vertex (active, carries the vertex value) and one node per directed edge
// (inactive, contributes nothing to size or sum).
//
// Supported operations: link / cut of an edge that was present in the graph
// passed to the constructor, adding a value to every vertex of a component,
// and querying the size or value-sum of a component.
//
// Nodes are allocated with `new` and are never freed by this struct.
struct ETTree {
  struct node {
    node *l = nullptr;   // left child in the splay tree
    node *r = nullptr;   // right child in the splay tree
    node *p = nullptr;   // parent in the splay tree (nullptr for the root)
    bool active;         // true for vertex nodes, false for edge nodes
    int size;            // number of active nodes in this subtree
    long long val = 0;   // value of this node (only meaningful when active)
    long long lazy = 0;  // pending addition for the children's subtrees
    long long sum = 0;   // sum of val over active nodes in this subtree

    // Initializers are listed in declaration order.
    node(bool active_, int v = 0) : active(active_), size(active_), val(v), sum(v) {}
  };

  std::map<std::pair<int, int>, node *> edges;  // directed edge (u, v) -> its tour node
  std::vector<node *> super;                    // vertex -> its tour node
  const int n;                                  // number of vertices

  // Builds the structure for graph `g` (adjacency lists, each undirected edge
  // listed in both directions) with initial vertex values `a`, and links every
  // edge so that the initial forest matches `g`.
  ETTree(const std::vector<std::vector<int>> &g, const std::vector<int> &a)
      : n(static_cast<int>(g.size())) {
    super.resize(n);
    for (int i = 0; i < n; i++) {
      for (int j : g[i]) {
        edges[{i, j}] = new node(false);
      }
      super[i] = new node(true, a[i]);
    }
    for (int i = 0; i < n; i++) {
      for (int j : g[i]) {
        // Each undirected edge appears twice in g; link it once.
        if (i < j) link(i, j);
      }
    }
  }

  // Connects the components of u and v with the edge (u, v).
  // The resulting tour is [u ... uv v ... vu].
  // Throws std::out_of_range if (u, v) was not an edge of the constructor graph.
  void link(int u, int v) {
    node *uv = edges.at({u, v});
    node *vu = edges.at({v, u});
    node *x = super[u];
    node *y = super[v];
    rotate(x);
    rotate(y);
    join(x, uv);
    join(uv, y);
    join(y, vu);
  }

  // Removes the edge (u, v), splitting the tour into the two sides of the edge.
  // Throws std::out_of_range if (u, v) was not an edge of the constructor graph.
  void cut(int u, int v) {
    node *uv = edges.at({u, v});
    node *vu = edges.at({v, u});
    // After the rotation the tour reads [uv A... vu B...]; peel off uv and vu.
    rotate(uv);
    split_right(uv);
    split_left(vu);
    split_right(vu);
  }

  // Adds w to the value of every vertex in u's component.
  void add(int u, long long w) {
    put_lazy(splay(super[u]), w);
  }

  // Number of vertices in u's component.
  int find_size(int u) {
    return splay(super[u])->size;
  }

  // Sum of the values of all vertices in u's component.
  long long access(int u) {
    return splay(super[u])->sum;
  }

  // Single rotation lifting x above its parent; re-aggregates the old parent.
  // x itself is re-aggregated by splay() once it reaches the root.
  void rot(node *x) {
    node *y = x->p, *z = y->p;
    if (z) {
      if (y == z->l) z->l = x;
      if (y == z->r) z->r = x;
    }
    x->p = z;
    y->p = x;
    if (x == y->l) {
      y->l = x->r;
      x->r = y;
      if (y->l) y->l->p = y;
    } else {
      y->r = x->l;
      x->l = y;
      if (y->r) y->r->p = y;
    }
    pull(y);
  }

  // Pushes pending additions from the root down to x (inclusive).
  void push_above(node *x) {
    if (x->p) push_above(x->p);
    push(x);
  }

  // Splays x to the root of its tree and returns it.
  node *splay(node *x) {
    push_above(x);
    while (x->p) {
      node *y = x->p, *z = y->p;
      if (z) rot((x == y->l) == (y == z->l) ? y : x);
      rot(x);
    }
    pull(x);
    return x;
  }

  // Recomputes size and sum of x from its children and its own value.
  void pull(node *x) {
    x->size = x->active;
    if (x->l) x->size += x->l->size;
    if (x->r) x->size += x->r->size;
    x->sum = 0;
    if (x->active) {
      x->sum += x->val;
    }
    if (x->l) x->sum += x->l->sum;
    if (x->r) x->sum += x->r->sum;
  }

  // Applies "+v to every active node" to x's subtree: x is updated immediately,
  // the children are updated later through push().
  void put_lazy(node *x, long long v) {
    x->lazy += v;
    if (x->active) {
      x->val += v;
    }
    x->sum += v * x->size;
  }

  // Hands x's pending addition to its children.
  void push(node *x) {
    if (x->l) put_lazy(x->l, x->lazy);
    if (x->r) put_lazy(x->r, x->lazy);
    x->lazy = 0;
  }

  // [_ _ x _ _] + [_ _ y _ _] -> [_ _ x _ _ _ _ y _ _]
  // Does nothing if either argument is null.
  void join(node *x, node *y) {
    if (!x || !y) return;
    splay(x);
    while (x->r) x = x->r;
    splay(x);  // x is now the last element, so it has no right child
    splay(y);
    x->r = y;
    y->p = x;
    pull(x);
  }

  // [_ _ x _ _] -> [_ _] [x _ _]
  // Returns the root of the detached left part, or nullptr if it is empty.
  node *split_left(node *x) {
    splay(x);
    node *y = x->l;
    if (y) {
      y->p = x->l = nullptr;
      pull(x);
    }
    return y;
  }

  // [_ _ x _ _] -> [_ _ x] [_ _]
  // Returns the root of the detached right part, or nullptr if it is empty.
  node *split_right(node *x) {
    splay(x);
    node *y = x->r;
    if (y) {
      y->p = x->r = nullptr;
      pull(x);
    }
    return y;
  }

  // [0 1 2 x 3 4 5] -> [x 3 4 5 0 1 2]
  void rotate(node *x) {
    join(x, split_left(x));
  }
};

// One vertex of the link-cut tree.
//   l, r : children inside the auxiliary splay tree
//   p    : splay parent, or a link to another splay tree when this node is
//          not one of p's children (see is_root)
//   rev  : pending "swap the children of my children" flag
struct node {
  node *l{nullptr};
  node *r{nullptr};
  node *p{nullptr};
  bool rev{false};
};

// True when x is the root of its splay tree: either it has no parent pointer
// at all, or its parent pointer leads to a node that does not list x as a child.
bool is_root(node *x) {
  node *parent = x->p;
  if (parent == nullptr) {
    return true;
  }
  return parent->l != x && parent->r != x;
}

// Single rotation that lifts x above its parent.
// If the grandparent exists but does not hold the parent as a child, the
// grandparent's child pointers are left alone and only x->p is redirected.
void rot(node *x) {
  node *parent = x->p;
  node *grand = parent->p;
  const bool x_was_left = (parent->l == x);

  if (grand != nullptr) {
    if (grand->l == parent) grand->l = x;
    if (grand->r == parent) grand->r = x;
  }
  x->p = grand;
  parent->p = x;

  if (x_was_left) {
    // x's right subtree moves under the old parent's left slot.
    node *moved = x->r;
    parent->l = moved;
    x->r = parent;
    if (moved != nullptr) moved->p = parent;
  } else {
    // Mirror image: x's left subtree moves under the old parent's right slot.
    node *moved = x->l;
    parent->r = moved;
    x->l = parent;
    if (moved != nullptr) moved->p = parent;
  }
}

// Swaps x's children right away and toggles the flag that tells push() to do
// the same to the children later.
void reverse(node *x) {
  node *old_left = x->l;
  x->l = x->r;
  x->r = old_left;
  x->rev = !x->rev;
}

// Hands a pending reversal on x down to its children and clears the flag.
void push(node *x) {
  if (!x->rev) {
    return;  // nothing pending; the flag is already false
  }
  node *left = x->l;
  node *right = x->r;
  if (left != nullptr) reverse(left);
  if (right != nullptr) reverse(right);
  x->rev = false;
}

// Applies push() to every node reachable through p pointers from x, topmost
// first and x last, so that x's own child pointers are up to date afterwards.
void push_above(node *x) {
  node *const up = x->p;
  if (up != nullptr) {
    push_above(up);
  }
  push(x);
}

// Rotates x up until it is the root of its splay tree (as defined by is_root).
// Does not push reversal flags itself.
void splay(node *x) {
  while (!is_root(x)) {
    node *parent = x->p;
    if (!is_root(parent)) {
      node *grand = parent->p;
      const bool x_is_left = (parent->l == x);
      const bool parent_is_left = (grand->l == parent);
      if (x_is_left == parent_is_left) {
        rot(parent);  // same side twice: rotate the parent first
      } else {
        rot(x);       // opposite sides: rotate x twice
      }
    }
    rot(x);
  }
}

// Access operation of the link-cut tree: after pushing all pending reversals
// above x, walks up through the p pointers, splaying each visited node and
// replacing its right child with the tree built so far. Finally splays x, so
// x ends up as the root of the topmost splay tree with no right child.
void splayLC(node *x) {
  push_above(x);
  node *below = nullptr;
  node *cur = x;
  while (cur != nullptr) {
    splay(cur);
    cur->r = below;
    below = cur;
    cur = cur->p;
  }
  splay(x);
}

// Makes x the root of its represented tree: access x, then reverse the whole
// accessed path by swapping x's children and toggling its reversal flag.
void reroot(node *x) {
  splayLC(x);
  node *old_left = x->l;
  x->l = x->r;
  x->r = old_left;
  x->rev = !x->rev;
}

// Attaches the tree containing c below p: c is first made the root of its own
// tree, p is accessed, and c is hung as p's right child.
void link(node *c, node *p) {
  reroot(c);
  splayLC(p);
  c->p = p;
  p->r = c;
}

// Detaches x (together with everything below it) from its parent.
// After the access, x's left subtree holds the nodes above x on the root path.
// If that subtree is empty, x is already a root and there is nothing to cut;
// the original dereferenced a null pointer in that case.
void cut(node *x) {
  splayLC(x);
  node *above = x->l;
  if (above == nullptr) {
    return;
  }
  above->p = nullptr;
  x->l = nullptr;
}

// Lowest common ancestor of x and y with respect to the current root.
// After accessing y and then x, the function climbs from y through p pointers.
// Each time it leaves a splay tree through a non-child parent link it records
// that parent; the last one recorded (or y itself if none) is the answer.
// Returns nullptr when the climb never passes through x.
node *lca(node *x, node *y) {
  splayLC(y);
  splayLC(x);
  node *meet = y;
  bool reached_x = false;
  for (node *cur = y; cur != nullptr; cur = cur->p) {
    if (cur->p != nullptr && is_root(cur)) {
      meet = cur->p;
    }
    // NOTE(review): x->r looks like it is always null right after splayLC(x),
    // which would make this early return unreachable - confirm before removing.
    if (cur == x->r) {
      return x;
    }
    if (cur == x) {
      reached_x = true;
    }
  }
  return reached_x ? meet : nullptr;
}

// Returns the parent of x in the represented tree (with respect to the current
// root), or nullptr if x is the root.
//
// After accessing x, its parent is the last node in the in-order sequence of
// x's left subtree, i.e. the rightmost node of that subtree. Every node must
// be pushed BEFORE its children are inspected: a pending reversal on a node
// means its children's own child pointers are still unswapped. The original
// skipped the push on the left child itself, so the walk could follow stale
// pointers of the next node and return the wrong vertex.
node *get_parent(node *x) {
  splayLC(x);
  if (!x->l) {
    return nullptr;
  }
  x = x->l;
  push(x);
  while (x->r) {
    x = x->r;
    push(x);
  }
  splayLC(x);  // access the result, as the original did
  return x;
}

// Reads one decimal integer from standard input.
// Returns 0 if no integer could be read; the original returned an
// indeterminate value in that case.
int input() {
  int value = 0;
  if (scanf("%d", &value) != 1) {
    return 0;
  }
  return value;
}

int main() {
int n, q;
cin >> n;
cin >> q;
vector<node *> lc(n);
map<node *, int> mp;
mp[nullptr] = -1;
vector<vector<int>> g(n);
for (int i = 0; i < n; i++) {
lc[i] = new node();
mp[lc[i]] = i;
}
vector<int> a(n);
for (int i = 0; i < n; i++) {
a[i] = input();
}
for (int i = 0; i < n - 1; i++) {
int u = input() - 1;
int v = input() - 1;
g[u].push_back(v);
g[v].push_back(u);
}
reroot(lc[0]);

ETTree et(g, a);
while (q--) {
int type = input();
if (type == 1) {
int v = input() - 1;
reroot(lc[v]);
} else if (type == 2) {
int u = input() - 1;
int v = input() - 1;
int x = input();

int l = mp[lca(lc[u], lc[v])];
int p = mp[get_parent(lc[l])];

if (p != -1) et.cut(l, p);
if (p != -1) et.link(l, p);
} else {
int v = input() - 1;
int p = mp[get_parent(lc[v])];
if (p != -1) et.cut(v, p);
long long ans = et.access(v);
if (p != -1) et.link(v, p);
printf("%lld\n", ans);
}
}
}


#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <functional>

using namespace std;

// Reads one decimal integer from standard input.
// Returns 0 if no integer could be read; the original returned an
// indeterminate value in that case.
int input() {
  int value = 0;
  if (scanf("%d", &value) != 1) {
    return 0;
  }
  return value;
}

// Heavy-light decomposition of a tree rooted at vertex 0.
//
//   parent[u] : parent of u (-1 for vertex 0)
//   head[u]   : topmost vertex of the heavy chain containing u
//   vid[u]    : visit index of u; a heavy chain occupies consecutive indices
//               because the heavy child is always visited first
//   inv[i]    : vertex whose visit index is i
//
// A child v of u is "heavy" when size[u] < 2 * size[v]; at most one child can
// satisfy this.
struct HLD {
  using Tree = std::vector<std::vector<int>>;

  std::vector<int> parent;
  std::vector<int> head;
  std::vector<int> vid;
  std::vector<int> inv;

  HLD(const Tree &g) : parent(g.size()), head(g.size()), vid(g.size()), inv(g.size()) {
    if (g.empty()) return;
    int k = 0;
    std::vector<int> size(g.size(), 1);
    std::function<void(int, int)> dfs = [&](int u, int p) {
      for (int v : g[u]) if (v != p) dfs(v, u), size[u] += size[v];
    };
    std::function<void(int, int, int)> dfs2 = [&](int u, int p, int h) {
      parent[u] = p; head[u] = h; vid[u] = k++; inv[vid[u]] = u;
      // Heavy child first (continues the chain), then light children (new chains).
      for (int v : g[u]) if (v != p && size[u] < size[v] * 2) dfs2(v, u, h);
      for (int v : g[u]) if (v != p && size[u] >= size[v] * 2) dfs2(v, u, v);
    };
    dfs(0, -1);
    dfs2(0, -1, 0);
  }

  // Returns the vertex where the three pairwise paths between a, b and c meet,
  // which is the LCA of a and b when the tree is rooted at c.
  //
  // Each round orders the three by descending vid. If the two largest share a
  // chain, the middle one (b) is an ancestor of a and the third vertex lies
  // outside b's subtree, so b is the meeting point. Otherwise neither b nor c
  // is inside the subtree of head[a], so a can be replaced by parent[head[a]]
  // without changing the answer.
  int lca(int a, int b, int c) {
    while (true) {
      if (vid[a] < vid[b]) std::swap(a, b);
      if (vid[a] < vid[c]) std::swap(a, c);
      if (vid[b] < vid[c]) std::swap(b, c);
      if (head[a] == head[b]) return b;
      a = parent[head[a]];
    }
  }

  // Returns the parent of u when the tree is rooted at r, or -1 if u == r.
  // If r lies inside u's original subtree the answer is the child of u towards
  // r; otherwise it is the original parent[u].
  int find_parent(int u, int r) {
    if (u == r) return -1;
    while (vid[u] < vid[r]) {
      // Same chain and r is deeper: the next vertex on the chain leads to r.
      if (head[u] == head[r]) return inv[vid[u] + 1];
      // r's chain hangs directly off u.
      if (parent[head[r]] == u) return head[r];
      r = parent[head[r]];
    }
    return parent[u];
  }
};

// Pre-order numbering of a tree rooted at vertex 0.
// After construction, the vertices in the subtree of u are exactly those whose
// entry index lies in the half-open range [l[u], r[u]).
struct EulerTour {
  std::vector<int> l;  // entry index of each vertex
  std::vector<int> r;  // one past the largest entry index inside the subtree

  EulerTour(const std::vector<std::vector<int>> &g) {
    const int n = g.size();
    l.resize(n);
    r.resize(n);

    // Explicit-stack depth-first search; each frame remembers which neighbour
    // to look at next.
    struct Frame {
      int vertex;
      int parent;
      std::size_t next;
    };
    std::vector<Frame> stack;
    int clock = 0;

    l[0] = clock++;
    stack.push_back(Frame{0, -1, 0});
    while (!stack.empty()) {
      Frame &top = stack.back();
      const std::vector<int> &adj = g[top.vertex];
      if (top.next == adj.size()) {
        r[top.vertex] = clock;
        stack.pop_back();
        continue;
      }
      const int child = adj[top.next++];
      if (child == top.parent) {
        continue;
      }
      const int from = top.vertex;  // read before push_back may reallocate
      l[child] = clock++;
      stack.push_back(Frame{child, from, 0});
    }
  }
};

// Fenwick tree pair supporting range add and range sum over positions 0..n-1.
// All ranges are half-open: [l, r).
//
// s1 accumulates the slope and s0 the offset, so that the prefix sum over
// [0, k) is (sum of s0 up to k) + (sum of s1 up to k) * k.
template<class T>
struct BIT {
  std::vector<T> s0;
  std::vector<T> s1;

  BIT(int n) : s0(n + 2), s1(n + 2) {}

  // Adds v to every position >= k.
  void add0(int k, T v) {
    const int limit = static_cast<int>(s0.size());
    for (int i = k + 1; i < limit; i += i & -i) {
      s0[i] -= v * k;
      s1[i] += v;
    }
  }

  // Adds v to every position in [l, r). The original body was empty, which
  // made every range update a silent no-op.
  void add(int l, int r, T v) {
    add0(l, v);
    add0(r, -v);
  }

  // Sum of positions [0, k).
  T sum0(int k) {
    T t0 = 0;
    T t1 = 0;
    for (int i = k + 1; i > 0; i -= i & -i) {
      t0 += s0[i];
      t1 += s1[i];
    }
    return t0 + t1 * k;
  }

  // Sum of positions [l, r).
  T sum(int l, int r) {
    return sum0(r) - sum0(l);
  }

  // Sum of all n positions.
  T sumall() {
    return sum0(static_cast<int>(s0.size()) - 2);
  }
};

// Solves CF 916E offline-free with a fixed rooting at vertex 0:
// HLD answers "LCA under the current root" and "parent under the current
// root"; an Euler tour maps original subtrees to index ranges; a range-add /
// range-sum BIT stores the values. `all` holds additions that apply to the
// whole tree, so "everything except the original subtree of p" is expressed
// as all += x plus a range add of -x on p's range.
//
// Input: n q, then n values, then n-1 edges (1-based), then q queries:
//   1 v      - make v the root
//   2 u v x  - add x to every vertex in the subtree of lca(u, v)
//   3 v      - print the sum of the subtree of v
int main() {
  const int n = input();
  int q = input();

  std::vector<std::vector<int>> g(n);
  std::vector<int> a(n);
  for (int i = 0; i < n; i++) {
    a[i] = input();
  }
  for (int i = 0; i < n - 1; i++) {
    const int u = input() - 1;
    const int v = input() - 1;
    g[u].push_back(v);
    g[v].push_back(u);
  }
  EulerTour et(g);
  HLD hld(g);

  int root = 0;
  long long all = 0;  // amount added to every vertex of the tree

  BIT<long long> bit(n);
  for (int i = 0; i < n; i++) {
    // Load the initial value of vertex i at its Euler-tour position.
    bit.add(et.l[i], et.l[i] + 1, a[i]);
  }

  while (q--) {
    const int type = input();
    if (type == 1) {
      root = input() - 1;
    } else if (type == 2) {
      const int u = input() - 1;
      const int v = input() - 1;
      const int x = input();

      const int l = hld.lca(u, v, root);
      const int p = hld.find_parent(l, root);

      if (p == -1) {
        // l is the current root: its subtree is the whole tree.
        all += x;
      } else if (p == hld.parent[l]) {
        // Same orientation as the fixed rooting: the original subtree of l.
        bit.add(et.l[l], et.r[l], x);
      } else {
        // The root lies below l; p is l's child towards the root, and the
        // subtree of l is everything except the original subtree of p.
        all += x;
        bit.add(et.l[p], et.r[p], -x);
      }
    } else {
      const int v = input() - 1;
      const int p = hld.find_parent(v, root);

      long long ans = 0;
      if (p == -1) {
        ans = all * n + bit.sumall();
      } else if (p == hld.parent[v]) {
        ans = all * (et.r[v] - et.l[v]) + bit.sum(et.l[v], et.r[v]);
      } else {
        ans += all * n + bit.sumall();
        ans -= all * (et.r[p] - et.l[p]) + bit.sum(et.l[p], et.r[p]);
      }

      printf("%lld\n", ans);
    }
  }
}


ET-tree が遅い。