高速フーリエ変換 (FFT)

再帰型の FFT はやっぱり遅いので、非再帰型の FFT について調べた。『アルゴリズムイントロダクション 第 3 巻』の「多項式FFT」を参考にしている。多項式の次数は自由度(係数の個数)の意味で使っている。

多項式FFT

多項式を表現するとき、係数を持つ方法と点での値を持つ方法の 2 通りが考えられる。 点表現というのは $f(x_0),f(x_1),\ldots,f(x_{n-1})$ という $n$ 個の具体値で多項式を表現する方法である。 $(fg)(x_i)=f(x_i)g(x_i)$ なので点表現は乗算が $O(n)$ でできるという利点がある。

係数表現で乗算をするとき、係数表現から点表現への高速な変換(逆変換)があると嬉しい。 実は、多項式 $f$ の次数を $n$、$w=e^{2\pi/ni}$ としたとき、$f(w^0),f(w^1),f(w^2),\ldots,f(w^{n-1})$ は高速に求められる。 これを高速フーリエ変換 (FFT) と呼ぶ。

再帰FFT

多項式の次数は 2 の冪であるものとする。$w_n=e^{2\pi/ni}$ とする。$w_n^{i}$ での値を求めることを考えよう。

$$f(w_n^i)=a_0+a_1w_n^i+a_2w_n^{2i}+\cdots+a_{n-1}w_n^{(n-1)i}$$

偶数次と奇数次で分離する。

$$f(w_n^i)=(a_0+a_2w_n^{2i}+\cdots+a_{n-2}w_n^{(n-2)i}) +w_n^i(a_1+a_3w_n^{2i}+\cdots+a_{n-1}w_n^{(n-2)i})$$

$w_n^{2i}=w_{n/2}^{i}$であることを考えると次のように変形できる。

$$f(w_n^i)=(a_0+a_2w_{n/2}^{i}+\cdots+a_{n-2}w_{n/2}^{(n-2)/2i}) +w_n^i(a_1+a_3w_{n/2}^{i}+\cdots+a_{n-1}w_{n/2}^{(n-2)/2i})$$

これを見ると、多項式

$$ f_0(x)=a_0+a_2x+\cdots+a_{n-2}x^{(n-2)/2i}$$ $$ f_1(x)=a_1+a_3x+\cdots+a_{n-1}x^{(n-2)/2i}$$

について、単位円を $n/2$ 等分する点での値さえ分かれば計算できることが分かる。具体的には以下のようになる。

$$ f(w_n^i)=f_0(w_{n/2}^i)+w_n^{i}f_1(w_{n/2}^i) $$

$f_0,f_1$ は $n/2$ 次の多項式である。次数 $n$ の多項式FFT するのに掛かる時間を $T(n)$ とすると、$T(n)=2T(n/2)+O(n)=O(n \log n)$ となる。

再帰型を考える前に、もう少しだけこの式を変形する。$w_n^{i+n/2}=-w_n^{i}$ を使う。

$$ f(w_n^i)=f_0(w_{n/2}^i)+w_n^{i}f_1(w_{n/2}^i) $$ $$ f(w_n^{i+n/2})=f_0(w_{n/2}^i)-w_n^{i}f_1(w_{n/2}^i) $$

再帰FFT

再帰の様子を以下の図に示した。

f:id:pekempey:20161024162527p:plain

再帰で書くために、まず次のような置換をする。

$$(0,1,2,3,4,5,6,7)\to(0,4,2,6,1,5,3,7)$$ 何が起きてるのか分かりづらいが、二進表現で見てみると $$(000,001,010,011,100,101,110,111)\to(000,100,010,110,001,101,011,111)$$ となり、ビットを逆転させた位置に移動させているだけである。下位ビットの0/1で振り分けていることを考えれば何故こうなるのかは分かるだろう。

さて 4 段目から 3 段目への変換を考えてみよう。

f:id:pekempey:20161024163356p:plain

これは次のような計算になる。+は加算器、×は乗算器を表す。

f:id:pekempey:20161024165311p:plain

これは追加メモリなしで計算できる。

次に 3 段目から 2 段目への変換を考える。

f:id:pekempey:20161024170119p:plain

これも追加メモリなしで計算できる。

最後 2 段目から 1 段目への変換をくっつければ完成である。

f:id:pekempey:20161024171154p:plain

これはよく見る図。実際は一番最初にビット逆転をする。

AtCoder Typical Contest C 高速フーリエ変換 での実装例。298 ms。

#include <bits/stdc++.h>

std::vector<std::complex<double>> fft(std::vector<std::complex<double>> a, bool rev = false) {
    const double pi = std::acos(-1);
    int n = a.size();
    int h = 0;
    for (int i = 0; 1 << i < n; i++) h++;
    for (int i = 0; i < n; i++) {
        int j = 0;
        for (int k = 0; k < h; k++) j |= (i >> k & 1) << (h - 1 - k);
        if (i < j) std::swap(a[i], a[j]);
    }
    for (int i = 1; i < n; i *= 2) {
        for (int j = 0; j < i; j++) {
            std::complex<double> w = std::polar(1.0, 2 * pi / (i * 2) * (rev ? -1 : 1) * j);
            for (int k = 0; k < n; k += i * 2) {
                std::complex<double> s = a[j + k + 0];
                std::complex<double> t = a[j + k + i] * w;
                a[j + k + 0] = s + t;
                a[j + k + i] = s - t;
            }
        }
    }
    if (rev) for (int i = 0; i < n; i++) a[i] /= n;
    return a;
}

std::vector<int> mul(std::vector<int> a, std::vector<int> b) {
    int s = a.size() + b.size() - 1;
    int t = 1;
    while (t < s) t *= 2;
    std::vector<std::complex<double>> A(t), B(t);
    for (int i = 0; i < a.size(); i++) A[i].real(a[i]);
    for (int i = 0; i < b.size(); i++) B[i].real(b[i]);
    A = fft(A);
    B = fft(B);
    for (int i = 0; i < t; i++) A[i] *= B[i];
    A = fft(A, true);
    a.resize(s);
    for (int i = 0; i < s; i++) a[i] = round(A[i].real());
    return a;
}

int main() {
    int n;
    std::cin >> n;

    std::vector<int> a(n + 1), b(n + 1);

    for (int i = 1; i <= n; i++) {
        scanf("%d %d", &a[i], &b[i]);
    }

    a = mul(a, b);

    for (int i = 1; i <= 2 * n; i++) {
        printf("%d\n", a[i]);
    }
}
再帰は理解してたけど非再帰は空で書ける気がしなかったので理解した。spagetti source のライブラリの bit reverse のところがどうなってるのかが良く分かってなくて、いまでも良く分かってない。

おまけ

Number Theoretic Transform という mod 上の FFT みたいなのもある。サイズ n の NTT をするとき mod の原始 n 乗根が必要になるので、原始 220 乗根くらいがある mod でないと使えない。ただし畳み込みをするだけなら任意の mod でも可能。

typical contest の C は畳み込み後の最大値が 100×100×100000=109≦1012924417 なので、1012924417 を法とする NTT で計算しても結果が一致する。FFT より流石に速い。

#include <bits/stdc++.h>

struct NumberTheoreticTransform {
    int mod;
    int root;

    NumberTheoreticTransform(int mod, int root) : mod(mod), root(root) {}

    int mul(int x, int y) {
        return int64_t(x) * y % mod;
    }

    int add(int x, int y) {
        return (x += y) >= mod ? x - mod : x;
    }

    int pow(int x, int y) {
        int res = 1;
        while (y > 0) {
            if (y & 1) res = mul(res, x);
            x = mul(x, x);
            y >>= 1;
        }
        return res;
    }

    int inv(int x) {
        return pow(x, mod - 2);
    }

    void ntt(std::vector<int> &a, bool rev = false) {
        int n = a.size();
        int h = 0;
        for (int i = 0; 1 << i < n; i++) h++;
        for (int i = 0; i < n; i++) {
            int j = 0;
            for (int k = 0; k < h; k++) j |= (i >> k & 1) << (h - 1 - k);
            if (i < j) std::swap(a[i], a[j]);
        }
        for (int i = 1; i < n; i *= 2) {
            int w = pow(root, (mod - 1) / (i * 2));
            if (rev) w = inv(w);
            for (int j = 0; j < n; j += i * 2) {
                int wn = 1;
                for (int k = 0; k < i; k++) {
                    int s = a[j + k + 0];
                    int t = mul(a[j + k + i], wn);
                    a[j + k + 0] = add(s, t);
                    a[j + k + i] = add(s, mod - t);
                    wn = mul(wn, w);
                }
            }
        }
        int v = inv(n);
        if (rev) for (int i = 0; i < n; i++) a[i] = mul(a[i], v);
    }

    std::vector<int> mul(std::vector<int> a, std::vector<int> b) {
        int s = a.size() + b.size() - 1;
        int t = 1;
        while (t < s) t *= 2;
        a.resize(t);
        b.resize(t);
        ntt(a);
        ntt(b);
        for (int i = 0; i < t; i++) {
            a[i] = mul(a[i], b[i]);
        }
        ntt(a, true);
        a.resize(s);
        return a;
    }
};

int main() {
    int n;
    std::cin >> n;

    std::vector<int> a(n + 1), b(n + 1);

    for (int i = 1; i <= n; i++) {
        scanf("%d %d", &a[i], &b[i]);
    }

    NumberTheoreticTransform ntt(1012924417, 5);

    a = ntt.mul(a, b);

    for (int i = 1; i <= 2 * n; i++) {
        printf("%d\n", a[i]);
    }
}

変更履歴

2016/10/24 21:43:誤差に弱いコードだったみたいなので修正