World Codesprint 10: Maximum Disjoint Subtree Product

https://www.hackerrank.com/contests/world-codesprint-10/challenges/maximum-disjoint-subtree-product

解法

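全方位木 DP の知名度も高まってきたので、一般的な説明は省略する。以前に作成した全方位木 DP のライブラリを使用した。要点だけ述べると、各辺 i を新しい頂点 n+i として両端点の間に挿入し、その頂点での全方位木 DP の値として「辺 i を切って得られる 2 つの成分それぞれにおける連結部分木の和の最大値・最小値」を求める。答えは、全ての辺についての max(最大値×最大値, 最小値×最小値) の最大値である。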

#include <iostream>
#include <cstdio>
#include <vector>
#include <functional>
#include <algorithm>

// Rerooting DP: computes, for every vertex u, the DP value of the whole tree rooted at u.
//
//   g      : adjacency lists of a tree on vertices 0 .. g.size()-1. The tree is assumed to be
//            connected and acyclic; vertices unreachable from 0 keep the value T().
//   add    : T(const T &acc, const T &x). Folds one neighbor's value x into the accumulator acc,
//            i.e. {v1, ..., vm} + v(m+1). T() is used as the empty accumulator.
//   bundle : T(const T &acc, int u). Turns the folded values of u's neighbors {v1, ..., vm}
//            into the value of u itself.
//
// Each neighbor's value is the value of the component on that neighbor's side when the edge to
// u is cut. So dp[u] = bundle(add(...add(add(T(), d1), d2)..., dm), u), folded in adjacency order.
//
// Both passes walk an explicit BFS order instead of recursing. This keeps very deep trees (e.g.
// long paths) from overflowing the call stack. The total work is O(sum of degrees) add/bundle calls.
template<class T, class F1, class F2>
std::vector<T> freeTreeDP(const std::vector<std::vector<int>> &g, F1 add, F2 bundle) {
	const int n = static_cast<int>(g.size());
	std::vector<T> dp(n);
	if (n == 0) {
		return dp;
	}

	// BFS from vertex 0; every vertex appears after its parent in `order`.
	std::vector<int> order;
	std::vector<int> parent(n, -1);
	order.reserve(n);
	order.push_back(0);
	for (int k = 0; k < static_cast<int>(order.size()); k++) {
		const int u = order[k];
		for (int v : g[u]) {
			if (v != parent[u]) {
				parent[v] = u;
				order.push_back(v);
			}
		}
	}
	const int reached = static_cast<int>(order.size());

	// Pass 1 (leaves first): dp[u] = value of u's subtree when the tree is rooted at 0.
	for (int k = reached - 1; k >= 0; k--) {
		const int u = order[k];
		T acc = T();
		for (int v : g[u]) {
			if (v != parent[u]) {
				acc = add(acc, dp[v]);
			}
		}
		dp[u] = bundle(acc, u);
	}

	// Pass 2 (root first): up[v] = value of parent[v] computed without the edge to v. That is
	// the contribution of the parent's side when v becomes the root. When u is processed,
	// dp[child] still holds the pass-1 subtree value; the parent's side is available as up[u].
	std::vector<T> up(n);
	std::vector<T> suffix;
	for (int k = 0; k < reached; k++) {
		const int u = order[k];
		const std::vector<int> &adj = g[u];
		const int m = static_cast<int>(adj.size());
		// Value contributed to u by neighbor v (returned by value to stay safe for proxy types).
		auto side = [&](int v) -> T { return v == parent[u] ? up[u] : dp[v]; };

		// suffix[i] = add(side(adj[i]), add(side(adj[i + 1]), ... add(side(adj[m - 1]), T())))
		suffix.assign(m + 1, T());
		for (int i = m - 1; i >= 0; i--) {
			suffix[i] = add(side(adj[i]), suffix[i + 1]);
		}

		// prefix = fold of side(adj[0 .. i-1]); combined with suffix[i + 1] it excludes adj[i].
		T prefix = T();
		for (int i = 0; i < m; i++) {
			const int v = adj[i];
			if (v != parent[u]) {
				up[v] = bundle(add(prefix, suffix[i + 1]), u);
			}
			prefix = add(prefix, side(v));
		}
		dp[u] = bundle(prefix, u);
	}
	return dp;
}

int main() {
	const long long INF = 1e18;
	struct foo {
		long long max = -INF;
		long long min = INF;
		long long maxc = -INF;
		long long minc = INF;
		long long maxp;
		long long minp;
	};

	int n;
	std::cin >> n;

	std::vector<long long> w(n);
	for (int i = 0; i < n; i++) {
		scanf("%lld", &w[i]);
	}

	std::vector<std::vector<int>> tree(2 * n - 1);
	for (int i = 0; i < n - 1; i++) {
		int u, v;
		scanf("%d %d", &u, &v);
		u--;
		v--;
		tree[u].push_back(i + n);
		tree[i + n].push_back(u);
		tree[v].push_back(i + n);
		tree[i + n].push_back(v);
	}

	auto add = [&](const foo &a, const foo &b) {
		foo c;
		c.max = std::max(a.max, b.max);
		c.min = std::min(a.min, b.min);
		c.maxc = std::max(0LL, a.maxc) + std::max(0LL, b.maxc);
		c.minc = std::min(0LL, a.minc) + std::min(0LL, b.minc);
		c.maxp = a.max * b.max;
		c.minp = a.min * b.min;
		return c;
	};

	auto bundle = [&](const foo &a, int id) {
		foo c;
		if (id < n) {
			c.maxc = std::max(a.maxc + w[id], w[id]);
			c.minc = std::min(a.minc + w[id], w[id]);
			c.max = std::max(a.max, c.maxc);
			c.min = std::min(a.min, c.minc);
		} else {
			c = a;
		}
		return c;
	};

	auto dp = freeTreeDP<foo>(tree, add, bundle);

	long long ans = -1e18;
	for (int i = n; i < 2 * n - 1; i++) {
		ans = std::max(ans, std::max(dp[i].maxp, dp[i].minp));
	}
	std::cout << ans << std::endl;
}