https://yukicoder.me/problems/no/569
解法そのものは自明だが、真面目にやるべきものではない。行列による解法の存在に気づければ、解は線形漸化式に従うことが分かるので、その性質を使って楽をする。
解の列に対する母関数 a(x) を考える。ここである多項式 f(x) が存在して、f(x)a(x) の次数が L で抑えられる(つまり次数が高い場所では係数が 0 になる)。このような多項式を線形回帰列に対する annihilator と呼ぶらしい。annihilator を求めるアルゴリズムとして Berlekamp-Massey アルゴリズムが知られている [1]。具体値が既知で漸化式が未知のときに漸化式を求めるアルゴリズムとも考えられる。
annihilator を計算する際に、数列の具体値を計算する必要がある。これは graphillion などの ZDD 関連のライブラリを用いて計算できるが、自分が書いたコードは何故か遅かった(何かテクニックがあるのだろうか?)。以前に simpath を書いたことがあったので、そちらを使うことにする [2]。フロンティアが極めて小さいため、各フェーズにおける状態数は10~20個程度にしかならない。そのため、n=30 程度までなら一瞬で計算できる。試しに n=10^5 で計算させてみたところ 1.58 秒で結果が得られた。
#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
#include <set>
#include <cassert>
using namespace std;
constexpr int mod = 1e9 + 7;
struct modint {
int n;
modint(int n = 0) : n(n) {}
};
modint operator+(modint a, modint b) { return modint((a.n += b.n) >= mod ? a.n - mod : a.n); }
modint operator-(modint a, modint b) { return modint((a.n -= b.n) < 0 ? a.n + mod : a.n); }
modint operator*(modint a, modint b) { return modint(1LL * a.n * b.n % mod); }
bool operator<(modint a, modint b) { return a.n < b.n; }
modint &operator+=(modint &a, modint b) { return a = a + b; }
modint &operator-=(modint &a, modint b) { return a = a - b; }
modint &operator*=(modint &a, modint b) { return a = a * b; }
modint modinv(modint n) {
if (n.n == 1) return 1;
return modinv(mod % n.n) * (mod - mod / n.n);
}
modint operator/(modint a, modint b) { return a * modinv(b); }
modint simpath(int n) {
constexpr int H = 4;
vector<pair<int, int>> es;
for (int x = 0; x < n; x++) {
for (int y = 0; y < H; y++) {
if (y + 1 < H) es.emplace_back(x * H + y, x * H + y + 1);
if (x + 1 < n) es.emplace_back(x * H + y, x * H + y + H);
}
}
vector<int> deg(n * H);
for (auto e : es) {
deg[e.first]++;
deg[e.second]++;
}
map<map<int, int>, modint> dp0, dp1;
dp0[{{0, 0}}] = 1;
for (auto e : es) {
const int u = e.first;
const int v = e.second;
deg[u]--;
deg[v]--;
dp1.clear();
for (auto kv : dp0) {
map<int, int> mate1 = kv.first;
if (!kv.first.count(v)) mate1[v] = v;
const int mu = mate1[u];
const int mv = mate1[v];
const bool del = deg[u] == 0 && u != 0;
if (del) mate1.erase(u);
if (deg[u] != 0 || (u != 0 && mu == u) || (u == 0 && mu != u) || mu == -1) dp1[mate1] += kv.second;
if (mu == v || mu == -1 || mv == -1) continue;
mate1[u] = -1;
mate1[v] = -1;
mate1[mu] = mv;
mate1[mv] = mu;
if (mate1[0] == -1) continue;
if (del) mate1.erase(u);
if (!del || mu != u) dp1[mate1] += kv.second;
}
swap(dp0, dp1);
}
return dp0.rbegin()->second;
}
vector<modint> berlekamp_massey(vector<modint> s) {
const int N = s.size();
vector<modint> C(N);
vector<modint> B(N);
C[0] = 1;
B[0] = 1;
int L = 0;
int m = 1;
modint b = 1;
for (int n = 0; n < N; n++) {
modint d = s[n];
for (int i = 1; i <= L; i++) d += C[i] * s[n - i];
if (d.n == 0) {
m++;
} else if (2 * L <= n) {
auto T = C;
for (int i = 0; i + m < N; i++) C[i + m] -= B[i] * (d / b);
L = n + 1 - L;
B = T;
b = d;
m = 1;
} else {
for (int i = 0; i + m < N; i++) C[i + m] -= B[i] * (d / b);
m++;
}
}
C.resize(L + 1);
reverse(C.begin(), C.end());
assert(L < N - 1);
return C;
}
vector<modint> poly_mod(vector<modint> a, const vector<modint> &m) {
const int n = m.size();
for (int i = a.size() - 1; i >= m.size(); i--) {
for (int j = 0; j < m.size(); j++) {
a[i - n + j] += a[i] * m[j];
}
}
a.resize(m.size());
return a;
}
vector<modint> poly_mul(const vector<modint> &a, const vector<modint> &b, const vector<modint> &m) {
vector<modint> ret(a.size() + b.size() - 1);
for (int i = 0; i < a.size(); i++) {
for (int j = 0; j < b.size(); j++) {
ret[i + j] += a[i] * b[j];
}
}
return poly_mod(ret, m);
}
vector<modint> nth_power(long long n, const vector<modint> &m) {
vector<modint> ret(1);
vector<modint> x(2);
ret[0] = x[1] = 1;
while (n > 0) {
if (n & 1) ret = poly_mul(ret, x, m);
x = poly_mul(x, x, m);
n /= 2;
}
return poly_mod(ret, m);
}
int main() {
vector<modint> a(30);
for (int i = 0; i < a.size(); i++) {
a[i] = simpath(i + 1);
}
vector<modint> m = berlekamp_massey(a);
m.pop_back();
for (int i = 0; i < m.size(); i++) {
m[i] *= mod - 1;
}
long long n;
cin >> n;
auto x = nth_power(n, m);
modint ans;
for (int i = 0; i < x.size(); i++) {
ans += x[i] * a[i];
}
cout << ans.n << endl;
}
問題に関係のない simpath 関連のコードは以下の gist にまとめた。速さよりもコンパクトさを求めて書いた。まだ短くなる気がするので、劇的な改善策を見つけたい。コードもまだ綺麗とは言えない状態。『超高速グラフ列挙アルゴリズム』をまだ読んでないんだけど、もしかしたらコード、あるいは実装方法が載ってたりする?流石に載ってない?
https://gist.github.com/pekempey/80fb0978306cc077afa2800dd214cddd
[1] https://en.wikipedia.org/wiki/Berlekamp%E2%80%93Massey_algorithm
[2] http://pekempey.hatenablog.com/entry/2017/01/26/203424