読者です 読者をやめる 読者になる 読者になる

pekempeyのブログ

競技プログラミングに関する話題を書いていきます。

動的計画法入門(数え上げ)

数え上げ系の DP について説明する。

この記事は DP 初心者を対象にしている。DP やるだけみたいなことを一度でも考えたことがある人は対象にしていない。

例題

次の問題を考えてみよう。

  • N 桁以下の 3 の倍数はいくつあるか。

N が小さければ全列挙できる。i 桁の数がすべて列挙できていれば、その後ろに 0~9 を付け足せば i+1 桁の数をすべて列挙できることを使う。

  • dp[i] := leading zero を含めて i 桁のすべての非負整数の集合
#include <iostream>
#include <string>
#include <vector>
using namespace std;

int modulo(string s, int mod) {
    int ret = 0;
    for (char c : s) ret = (ret * 10 + (c - '0')) % mod;
    return ret;
}

int main() {
    int n;
    cin >> n;

    static vector<string> dp[101010]; // i 桁の数すべて
    dp[0].push_back("");
    for (int i = 0; i < n; i++) {
        for (string s : dp[i]) {
            for (char d = '0'; d <= '9'; d++) {
                dp[i + 1].push_back(s + d);
            }
        }
    }

    int ans = 0;
    for (string s : dp[n]) if (modulo(s, 3) == 0) ans++;
    cout << ans << endl;

    return 0;
}

すでによく見る DP っぽい。この全列挙こそが数え上げ系 DP のベースになる。 これをベースに競プロっぽい DP まで修正していこう。

最後に 3 の倍数判定をしているのが無駄な気がする。そこで

  • dp[i][j] := i 桁以下で 3 で割った余りが j であるような非負整数の集合

として、最終的な結果が dp[N][0] に入るようにしよう。

#include <iostream>
#include <string>
#include <vector>
using namespace std;

int main() {
    int n;
    cin >> n;

    static vector<string> dp[101010][3]; // i 桁で mod 3 が j になる
    dp[0][0].push_back("");
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < 3; j++) {
            for (string s : dp[i][j]) {
                for (char d = '0'; d <= '9'; d++) {
                    dp[i + 1][(j + (d - '0')) % 3].push_back(s + d);
                }
            }
        }
    }

    int ans = dp[n][0].size();
    cout << ans << endl;

    return 0;
}

この処理に具体的な値なんて必要ない。i 桁で mod 3 が j になるような数の総数だけ覚えておこう。

#include <iostream>
#include <string>
#include <vector>
using namespace std;

int main() {
    int n;
    cin >> n;

    static int dp[101010][3]; // i 桁で mod 3 が j になる
    dp[0][0] = 1;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < 3; j++) {
            for (int k = 0; k < dp[i][j]; k++) {
                for (int d = 0; d <= 9; d++) {
                    dp[i + 1][(j + d) % 3]++;
                }
            }
        }
    }

    int ans = dp[n][0];
    cout << ans << endl;

    return 0;
}

dp[i][j] 回ループする必要もなさそう。

#include <iostream>
#include <string>
#include <vector>
using namespace std;

int main() {
    int n;
    cin >> n;

    static int dp[101010][3]; // i 桁で mod 3 が j になる
    dp[0][0] = 1;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < 3; j++) {
            for (int d = 0; d <= 9; d++) {
                dp[i + 1][(j + d) % 3] += dp[i][j];
            }
        }
    }

    int ans = dp[n][0];
    cout << ans << endl;

    return 0;
}

これで完成。

DP を考える上で最も重要なのは、具体的な値が分からなくても処理できるようなグループ分けをすること。

グループ分けの方法というのは意外と似たようなものが多い。 だから DP は数をこなせば解けるようになる。

おまけ

次のような問題も考えてみよう。

  • A 以下の 3 の倍数はいくつあるか。

先ほどと同様に i 桁で mod 3 が j になるようなものを全列挙してから、A 以下のものがいくつあるか数えればよい。

#include <iostream>
#include <string>
#include <vector>
using namespace std;

int main() {
    string A;
    cin >> A;
    int n = A.size();

    static vector<string> dp[101010][3]; // i 桁で mod 3 が j になる
    dp[0][0].push_back("");
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < 3; j++) {
            for (string s : dp[i][j]) {
                for (char d = '0'; d <= '9'; d++) {
                    dp[i + 1][(j + d) % 3].push_back(s + d);
                }
            }
        }
    }

    int ans = 0;
    for (string s : dp[n][0]) if (s <= A) ans++;
    cout << ans << endl;

    return 0;
}

しかしこの処理は具体的な値に依存している。最後に A と比較しているのが DP に向いていない。一番最後に A と比較するのではなく、A と比較しながら列挙していけばよい。

  • dp[i 桁][A 未満であることが確定][mod 3 の値] := このパターンに当てはまる数の総数

こうすると具体的な値を覚える必要がなくなる。とはいえ、やや難しいのでコードが理解できなくても気にすることはない。

#include <iostream>
#include <string>
#include <vector>
using namespace std;

int main() {
    string A;
    cin >> A;
    int n = A.size();

    static int dp[101010][2][3];
    dp[0][0][0] = 1;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 3; k++) {
                int lim = j ? 9 : A[i] - '0';
                for (int d = 0; d <= lim; d++) {
                    dp[i + 1][j || d < lim][(k + d) % 3] += dp[i][j][k];
                }
            }
        }
    }

    int ans = dp[n][0][0] + dp[n][1][0];
    cout << ans << endl;

    return 0;
}