World Codesprint 10: Maximum Disjoint Subtree Product
https://www.hackerrank.com/contests/world-codesprint-10/challenges/maximum-disjoint-subtree-product
解法
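全方位木 DP の知名度も高まってきたので、全方位木 DP 自体の説明は省略する。以前に作成した全方位木 DP のライブラリを使用した。方針だけ述べると、各辺の中点に仮想頂点を挿入し、全方位木 DP で各仮想頂点から見た両側それぞれについて「連結な部分木の重み和」の最大値と最小値を求める。その辺で木を二つに分けたときの積の候補は (最大)×(最大) と (最小)×(最小) なので、全ての辺についてこれらの最大値を取れば答えになる。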
#include <iostream>
#include <cstdio>
#include <vector>
#include <functional>
#include <algorithm>

// Rerooting ("all-directions") tree DP.
//
// Returns, for every vertex u, bundle(A_u, u), where A_u folds the values of
// ALL neighbours of u with `add`. In other words, it is the DP value of u as if
// the tree were rooted at u. The value contributed by neighbour v is v's DP over
// the component on v's side when the edge u-v is cut.
//
//   add    :: (T, T) -> T    appends one neighbour value to an accumulator
//                            ({v1,...,vm} + v(m+1)).
//   bundle :: (T, int) -> T  turns the accumulator of vertex u into u's value
//                            (u -> {v1,...,vm}).
//
// A value-initialised T is the empty accumulator. Neighbours are folded in
// g[u] order (prefix left-to-right, suffix right-to-left), as in the classic
// recursive formulation. Both passes are iterative, so deep (path-like) trees
// cannot overflow the call stack. Vertex 0 is the internal root; g is assumed
// to describe a tree (not checked).
template<class T, class F1, class F2>
std::vector<T> freeTreeDP(const std::vector<std::vector<int>> &g, F1 add, F2 bundle) {
    const int n = static_cast<int>(g.size());
    if (n == 0) {
        return {};
    }

    // Preorder from vertex 0 using an explicit stack; parent[0] stays -1.
    std::vector<int> parent(n, -1);
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> stack{0};
    while (!stack.empty()) {
        const int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : g[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                stack.push_back(v);
            }
        }
    }

    // Pass 1 (children before parents): down[u] = value of u's subtree
    // when the tree is rooted at vertex 0.
    std::vector<T> down(n);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        T acc{};
        for (int v : g[u]) {
            if (v != parent[u]) {
                acc = add(acc, down[v]);
            }
        }
        down[u] = bundle(acc, u);
    }

    // Pass 2 (parents before children):
    //   up[v] = value of parent[v] with the branch towards v removed,
    //   dp[u] = final value of u over all of its neighbours.
    std::vector<T> up(n);
    std::vector<T> dp(n);
    std::vector<T> suffix;  // reused buffer to avoid one allocation per vertex
    for (int u : order) {
        const std::vector<int> &adj = g[u];
        const int m = static_cast<int>(adj.size());
        // Value that neighbour v contributes to u.
        auto side = [&](int v) -> const T & {
            return v == parent[u] ? up[u] : down[v];
        };
        // suffix[i] folds neighbours i+1 .. m-1; suffix[m-1] is empty.
        suffix.assign(m, T{});
        for (int i = m - 2; i >= 0; i--) {
            suffix[i] = add(side(adj[i + 1]), suffix[i + 1]);
        }
        T prefix{};
        for (int i = 0; i < m; i++) {
            const int v = adj[i];
            if (v != parent[u]) {
                // u seen from child v: every neighbour of u except v.
                up[v] = bundle(add(prefix, suffix[i]), u);
            }
            prefix = add(prefix, side(v));
        }
        dp[u] = bundle(prefix, u);
    }
    return dp;
}

int main() {
    constexpr long long INF = 1000000000000000000LL;  // 1e18, "no value yet" sentinel

    // Summary of one part of the tree: a rooted branch, or a fold of several
    // branches. "Connected sum" means the weight sum of a non-empty connected
    // set of real vertices.
    struct foo {
        long long max = -INF;   // largest connected sum inside the part
        long long min = INF;    // smallest connected sum inside the part
        // For a real vertex: best / worst connected sum of a set containing it.
        // For a fold or an edge vertex: sum over branches of max(0, maxc)
        // (resp. min(0, minc)), i.e. the best optional extension.
        long long maxc = -INF;
        long long minc = INF;
        // max*max and min*min of the two operands of the latest add() that had
        // both operands populated. Only read for edge vertices, where those
        // operands are the two sides of the edge.
        long long maxp = -INF;
        long long minp = -INF;
    };

    int n;
    if (!(std::cin >> n) || n < 1) {
        return 0;  // no (valid) tree to process
    }
    std::vector<long long> w(n);
    for (int i = 0; i < n; i++) {
        scanf("%lld", &w[i]);
    }

    // Subdivide every input edge i = (u, v) with an extra "edge vertex" n + i.
    // The graph then has 2n - 1 vertices, each edge vertex has exactly two
    // neighbours, and cutting edge i corresponds to splitting at vertex n + i.
    std::vector<std::vector<int>> tree(2 * n - 1);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        scanf("%d %d", &u, &v);
        u--;  // convert to 0-indexed
        v--;
        tree[u].push_back(i + n);
        tree[i + n].push_back(u);
        tree[v].push_back(i + n);
        tree[i + n].push_back(v);
    }

    // Folds branch b into accumulator a.
    auto add = [&](const foo &a, const foo &b) {
        foo c;
        c.max = std::max(a.max, b.max);
        c.min = std::min(a.min, b.min);
        c.maxc = std::max(0LL, a.maxc) + std::max(0LL, b.maxc);
        c.minc = std::min(0LL, a.minc) + std::min(0LL, b.minc);
        // Multiplying by a +-INF sentinel would overflow (undefined behaviour
        // for signed integers), so only form the products when both operands
        // describe real, populated parts. Real max/min become populated
        // together, so checking max suffices.
        // NOTE(review): assumes |connected sums| are small enough that their
        // product fits in long long; confirm against the problem constraints.
        if (a.max != -INF && b.max != -INF) {
            c.maxp = a.max * b.max;
            c.minp = a.min * b.min;
        }
        return c;
    };

    // Finishes a fold for vertex id. A real vertex adds its own weight to the
    // best/worst extension. An edge vertex carries no weight and passes the
    // fold through unchanged.
    auto bundle = [&](const foo &a, int id) {
        if (id >= n) {
            return a;
        }
        foo c;
        c.maxc = std::max(a.maxc + w[id], w[id]);
        c.minc = std::min(a.minc + w[id], w[id]);
        c.max = std::max(a.max, c.maxc);
        c.min = std::min(a.min, c.minc);
        return c;
    };

    auto dp = freeTreeDP<foo>(tree, add, bundle);

    // For each cut edge, the candidates are (max of one side) * (max of the
    // other side) and (min of one side) * (min of the other side).
    long long ans = -INF;
    for (int i = n; i < 2 * n - 1; i++) {
        ans = std::max({ans, dp[i].maxp, dp[i].minp});
    }
    std::cout << ans << '\n';
    return 0;
}