pekempeyのブログ

競技プログラミングに関する話題を書いていきます。

Codeforces #419 (Div.1) C. Karen and Supermarket

http://codeforces.com/contest/815/problem/C

解法

最適解は以下のような形である(入力が木になっていることに注意)。

f:id:pekempey:20170618152841p:plain

ここまで分かると dp[頂点][使用した頂点数][根と繋がっているかどうか] という DP ができることが分かる。しかしこれは O(n^3) ではないだろうか。いや、そうではなく O(n^2) になる。

二乗の木 DP - (iwi) { 反省します - TopCoder部

数式で計算量解析するのは、正しいことを確認するだけなら良いのだが本質を見失いやすい。なぜ O(n^2) になるのか直感的に説明しよう。

以下のような木がある。木DPではこれらの頂点をどんどんマージしていく。

f:id:pekempey:20170618153825p:plain

f:id:pekempey:20170618153952p:plain

f:id:pekempey:20170618154026p:plain

f:id:pekempey:20170618154048p:plain

f:id:pekempey:20170618154122p:plain

6,8 と 9,11 をマージするとき 2×2 回のループが回るというのは、(6,9),(6,11),(8,9),(8,11)というペア組を行うことに対応している(これは直積のサイズ)

マージしていく過程で、同じグループに含まれているもの同士は再ペア組されることはないから、n×n 回しかペア組は行われない。つまり計算量は O(n^2) である。

#include <iostream>
#include <cstdio>
#include <vector>
#include <string>
#include <algorithm>
#include <functional>

constexpr int INF = 1.01e9;

void upd(int &x, int y) {
	x = std::min(x, y);
}

std::vector<int> mul(std::vector<int> &x, std::vector<int> &y) {
	std::vector<int> z(x.size() + y.size() - 1, INF);
	for (int i = 0; i < x.size(); i++) {
		for (int j = 0; j < y.size(); j++) {
			upd(z[i + j], x[i] + y[j]);
		}
	}
	return z;
}

int main() {
	int n, w;
	std::cin >> n >> w;

	std::vector<std::vector<int>> g(n);
	std::vector<int> a(n), b(n);
	for (int i = 0; i < n; i++) {
		scanf("%d %d", &a[i], &b[i]);
		b[i] = a[i] - b[i];
		if (i > 0) {
			int p;
			scanf("%d", &p);
			g[p - 1].push_back(i);
		}
	}

	std::vector<std::vector<int>> dp0(n);
	std::vector<std::vector<int>> dp1(n);
	std::vector<std::vector<int>> dp01(n);

	std::function<void(int)> dfs = [&](int u) {
		dp0[u] = { 0, a[u] };
		dp1[u] = { INF, b[u] };
		for (int v : g[u]) {
			dfs(v);
			dp1[u] = mul(dp1[u], dp01[v]);
			dp0[u] = mul(dp0[u], dp0[v]);
		}
		dp01[u].resize(dp0[u].size());
		for (int i = 0; i < dp0[u].size(); i++) {
			dp01[u][i] = std::min(dp0[u][i], dp1[u][i]);
		}
	};
	dfs(0);

	int ans = 0;
	for (int i = 0; i < dp01[0].size(); i++) {
		if (dp01[0][i] <= w) {
			ans = i;
		}
	}
	std::cout << ans << std::endl;
}