Codeforces Round #329 (Div. 2) D. Happy Tree Party

解法

HL分解して積を頑張って計算する。正攻法ではなさそう。

HL分解の説明についてはこっちに書いておいた。pekempey.hatenablog.com

英語版Wikipediaの床関数のページを見るとfloor(floor(a/b)/c)=floor(a/bc)って書いてある。これは除算順序を交換できるという見方もできるし、先に分母の積を全部計算してから割っても等しいという見方もできる。この解法では後者の性質を用いた。

以下の様なクエリがあったとする。
f:id:pekempey:20151105200645p:plain

今回は辺に情報が乗っているので、その情報を頂点に下ろす。
f:id:pekempey:20151105201046p:plain

頂点のクエリに変えた場合、u,vのLCAを巻き込まないように気をつける必要がある。逆元が使えるからu-vパスのクエリを計算してLCAの分だけ取り除けばいい。

単純に積を計算するのは激ヤバなので、2種類のmodで計算して中国剰余定理で復元するテクを使う。しかしこれではaとa+mod1*mod2の区別がつかないので、別途logを計算して10^18より明らかに大きいと分かったらinfとして扱うようにする。ちなみに0の逆元が必要になったら詰むので、適当にmodを増やして回避する必要もありそう。(回避しなくてもACはできた)

LCA+eulertourも行ける気がする。(ただ0の逆元の対処が可能かは知らない)
こんな感じにルート間クエリが計算できる。
f:id:pekempey:20151105205210p:plain

一番最初からぐにょーと掛ければ大体打ち消して欲しい値が残る。
f:id:pekempey:20151105205216p:plain

#include <bits/stdc++.h>
#define GET_MACRO(a, b, c, NAME, ...) NAME
#define rep(...) GET_MACRO(__VA_ARGS__, rep3, rep2)(__VA_ARGS__)
#define rep2(i, a) rep3 (i, 0, a)
#define rep3(i, a, b) for (int i = (a); i < (b); i++)
#define repr(...) GET_MACRO(__VA_ARGS__, repr3, repr2)(__VA_ARGS__)
#define repr2(i, a) repr3 (i, 0, a)
#define repr3(i, a, b) for (int i = (b) - 1; i >= (a); i--)
#define chmin(a, b) ((b) < a && (a = (b), true))
#define chmax(a, b) (a < (b) && (a = (b), true))
using namespace std;
typedef long long ll;

const ll mod1 = 2147483647;
const ll mod2 = 1224736769;
const ll mod3 = 2000000011;

const ll modpow(ll a, ll b, ll mod) {
    ll res = 1;
    while (b) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b /= 2;
    }
    return res;
}

const ll modinv(ll a, ll mod) {
    return modpow(a, mod - 2, mod);
}

ll garner(ll a, ll b, ll mod1, ll mod2) {
    ll x = (b - a) * modinv(mod1, mod2);
    x %= mod2; x += mod2; x %= mod2;
    return a + x * mod1;
}

struct lint {
    ll x, y, z;
    double w;
    lint() : x(1), y(1), z(1), w(0) {}
    lint(ll a) : x(a % mod1), y(a % mod2), z(a % mod3), w(log(a)) {}
    lint operator *(lint b) { 
        lint res;
        res.x = x * b.x % mod1;
        res.y = y * b.y % mod2;
        res.z = z * b.z % mod3;
        res.w = w + b.w;
        return res;
    }
    lint operator /(lint b) {
        lint res;
        res.x = x * modinv(b.x, mod1) % mod1;
        res.y = y * modinv(b.y, mod2) % mod2;
        res.z = z * modinv(b.z, mod3) % mod3;
        res.w = w - b.w;
        return res;
    }
    lint &operator *=(const lint &b) {
        x = x * b.x % mod1;
        y = y * b.y % mod2;
        z = z * b.z % mod3;
        w += b.w;
        return *this;
    }
    lint &operator /=(const lint &b) {
        x = x * modinv(b.x, mod1) % mod1;
        y = y * modinv(b.y, mod2) % mod2;
        z = z * modinv(b.z, mod3) % mod3;
        w = w - b.w;
        return *this;
    }
    operator ll() {
        const double X = 18 * log(10) + log(2.0);
        if (w >= X) return 2e18;
        ll a = garner(x, y, mod1, mod2);
        ll b = garner(y, z, mod2, mod3);
        ll c = garner(z, x, mod3, mod1);
        if (a == 0 && b == 0 && c == 0) return 0;
        if (a != 0) return a;
        if (b != 0) return b;
        if (c != 0) return b;
    }
};

struct LCA {
    vector<vector<int>> G, parent;
    vector<int> depth;
    LCA(int n) : G(n), parent(24, vector<int>(n)), depth(n) {}
    void add(int u, int v) {
        G[u].push_back(v);
        G[v].push_back(u);
    }
    void build(int root = 0) {
        dfs(root, -1);
        rep (i, 23) rep (j, G.size()) {
            if (parent[i][j] == -1) parent[i + 1][j] = -1;
            else parent[i + 1][j] = parent[i][parent[i][j]];
        }
    }
    void dfs(int curr, int prev) {
        parent[0][curr] = prev;
        for (int next : G[curr]) if (next != prev) {
            depth[next] = depth[curr] + 1;
            dfs(next, curr);
        }
    }
    int query(int u, int v) {
        if (depth[u] < depth[v]) swap(u, v);
        repr (i, 24) if (depth[u] - depth[v] >= 1 << i) {
            u = parent[i][u];
        }
        if (u == v) return u;
        repr (i, 24) if (parent[i][u] != parent[i][v]) {
            u = parent[i][u];
            v = parent[i][v];
        }
        return parent[0][u];
    }
};

struct HL {
    vector<vector<int>> G, path;
    vector<int> parent, head, order, heads, heavy, depth;
    int root;
    HL(int n) : G(n), heavy(n, -1), parent(n), head(n), order(n), path(n), depth(n) {}
    void add(int u, int v) {
        G[u].push_back(v);
        G[v].push_back(u);
    }
    int dfs(int curr, int prev = -1) {
        int res = 1;
        head[curr] = curr;
        parent[curr] = prev;
        int maxv = -1;
        for (int next : G[curr]) if (next != prev) {
            int d = dfs(next, curr);
            res += d;
            if (maxv < d) maxv = d, heavy[curr] = next;
        }
        return res;
    }
    void dfs2(int curr, int prev = -1) {
        if (head[curr] == curr) heads.push_back(curr);
        path[head[curr]].push_back(curr);
        for (int next : G[curr]) if (next != prev) {
            if (next == heavy[curr]) {
                head[next] = head[curr];
                order[next] = order[curr] + 1;  
            }
            depth[next] = depth[curr] + (next != heavy[curr]);
            dfs2(next, curr);
        }
    }
    void build(int root = 0) {
        this->root = root;
        dfs(root);
        dfs2(root);
    }
    struct Iterator {
        int u, v;
        HL *hl;
        Iterator(HL *hl, int u, int v) : hl(hl), u(u), v(v) {}
        // head, [from, to)
        tuple<int, int, int> next() {
            if (hl->depth[u] == hl->depth[v]) {
                if (hl->head[u] == hl->head[v]) {
                    auto m = minmax(hl->order[u], hl->order[v]);
                    u = -1;
                    return make_tuple(hl->head[v], m.first, m.second + 1);
                }
            }
            if (hl->depth[u] < hl->depth[v]) swap(u, v);
            int pu = u;
            u = hl->parent[hl->head[u]];
            return make_tuple(hl->head[pu], 0, hl->order[pu] + 1);
        }
        bool has_next() {
            return u != -1;
        }
    };
    Iterator iterator(int u, int v) {
        return Iterator(this, u, v);    
    }
};

struct SegmentTree {
    vector<lint> prod;
    int size;

    void init(int n) {
        size = 1;
        while (size < n) size *= 2;
        prod.resize(size * 2);
    }

    void set(int k, ll v) {
        k += size - 1;
        prod[k] = lint(v);
    }

    void build() {
        repr (k, size - 1) {
            prod[k] = prod[k * 2 + 1] * prod[k * 2 + 2];
        }
    }

    lint get(int k) {
        k += size - 1;
        return prod[k];
    }

    void update(int k, ll v) {
        k += size - 1;
        prod[k] = lint(v);
        while (k) {
            k = (k - 1) / 2;
            prod[k] = prod[k * 2 + 1] * prod[k * 2 + 2];
        }
    }

    lint query(int a, int b, int k, int l, int r) {
        if (r <= a || b <= l) return lint();
        if (a <= l && r <= b) return prod[k];
        lint x = query(a, b, k * 2 + 1, l, (l + r) / 2);
        lint y = query(a, b, k * 2 + 2, (l + r) / 2, r);
        return x * y;
    }

    lint query(int a, int b) {
        return query(a, b, 0, 0, size);
    }
};

struct HLSolver {
    HL hl;
    vector<SegmentTree> tr;
    HLSolver(int n) : hl(n), tr(n) {}
    void add(int u, int v) {
        hl.add(u, v);
    }
    void build() {
        hl.build();
        for (int h : hl.heads) {
            tr[h].init(hl.path[h].size());
        }
    }
    void set(int u, ll x) {
        tr[hl.head[u]].set(hl.order[u], x);
    }
    void build_segment_tree() {
        for (int h : hl.heads) {
            tr[h].build();
        }
    }
    void update(int u, ll x) {
        tr[hl.head[u]].update(hl.order[u], x);
    }
    lint get(int u) {
        return tr[hl.head[u]].get(hl.order[u]);
    }
    ll query(int u, int v, int l) {
        lint res;
        res /= get(l);
        auto it = hl.iterator(u, v);
        while (it.has_next()) {
            int h, l, r;
            tie(h, l, r) = it.next();
            res *= tr[h].query(l, r);
        }   
        return (ll)res;
    }
};

int depth[202020];
vector<int> G[202020];

void dfs(int curr, int prev = -1) {
    for (int next : G[curr]) if (next != prev) {
        depth[next] = depth[curr] + 1;
        dfs(next, curr);
    }
}

int main() {
    int n, m;
    cin >> n >> m;
    vector<pair<int, int>> es(n - 1);
    vector<ll> val(n - 1);
    HLSolver hl(n);
    LCA lca(n);
    rep (i, n - 1) {
        int u, v;
        ll x;
        scanf("%d %d %I64d", &u, &v, &x);
        u--; v--;
        G[u].push_back(v);
        G[v].push_back(u);
        es[i] = make_pair(u, v);
        hl.add(u, v);
        val[i] = x;
        lca.add(u, v);
    }
    hl.build();
    lca.build();
    dfs(0);
    rep (i, n - 1) {
        int u, v;
        tie(u, v) = es[i];
        if (depth[u] > depth[v]) swap(es[i].first, es[i].second);
    }
    rep (i, n - 1) {
        int u, v;
        tie(u, v) = es[i];
        hl.set(v, val[i]);
    }
    hl.build_segment_tree();

    rep (i_, m) {
        int qid;
        scanf("%d", &qid);

        if (qid == 1) {
            int a, b;
            ll y;
            scanf("%d %d %I64d", &a, &b, &y);
            a--; b--;
            int l = lca.query(a, b);
            ll ans = hl.query(a, b, l);
            ans = y / ans;
            printf("%I64d\n", ans);
        } else {
            int p;
            ll c;
            scanf("%d %I64d", &p, &c);
            int u, v;
            tie(u, v) = es[p - 1];
            hl.update(v, c);
        }
    }
    return 0;
}

コメント

本番中はGarnerがバグってたり、辺を親子順にスワップしたつもりがスワップできてなかったりで結局WAが取れなかった。logの精度が気になるけど実際のところ安全なんだろうか。