FFT

(追記: rots の配置を変えたら、SIMD 版とそうでない版とで速度がまったく同じになった…。切ない)

多項式乗算を行うだけのコード。サイズ n の基数 2 の FFT を、サイズ n/4 の FFT 4 つに分割し、その 4 つを SIMD (AVX) で 4 並列に計算する (最後の 2 段のバタフライは通常のスカラー演算で行う)。実装が簡単だったので試してみた。

#pragma GCC optimize ("O3")
#pragma GCC target ("avx")
#include <iostream>
#include <algorithm>
#include <vector>
#include <complex>
#include <cstdio>
#include <cassert>
#include <immintrin.h>
#include <cmath>
#include <ctime>
#include <cstring>
#include <cassert>
#include <random>
 
using namespace std;
 
// Modulus constant carried over from a template; not referenced by this program.
constexpr int mod = 1000000007;

// Fixed seed so that every run benchmarks the same pseudo-random input.
std::mt19937 mt(123);

namespace vfft {
  /// Four complex doubles in split form: lane j of `x` is the real part and
  /// lane j of `y` the imaginary part of the j-th value.
  struct Complex256 {
    __m256d x;  // real parts, one per lane
    __m256d y;  // imaginary parts, one per lane
  };

  /// Lane-wise complex addition.
  inline Complex256 operator+(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_add_pd(a.x, b.x);
    res.y = _mm256_add_pd(a.y, b.y);
    return res;
  }

  /// Lane-wise complex subtraction.
  inline Complex256 operator-(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_sub_pd(a.x, b.x);
    res.y = _mm256_sub_pd(a.y, b.y);
    return res;
  }

  /// Lane-wise complex multiplication:
  /// (ax + i*ay)(bx + i*by) = (ax*bx - ay*by) + i*(ax*by + ay*bx).
  inline Complex256 operator*(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_sub_pd(_mm256_mul_pd(a.x, b.x), _mm256_mul_pd(a.y, b.y));
    res.y = _mm256_add_pd(_mm256_mul_pd(a.x, b.y), _mm256_mul_pd(a.y, b.x));
    return res;
  }

  /// Largest supported transform length (a power of two).
  const int N = 1 << 21;
  /// Length of each lane-parallel sub-transform when n == N.
  const int M = N / 4;
  /// vrots[j] = exp(2*pi*i*j / M), broadcast to all four lanes, for j in [0, M/2).
  Complex256 vrots[M / 2];
  /// rots[j] = exp(2*pi*i*j / N), for j in [0, N/2).
  std::complex<double> rots[N / 2];

  /// Fills the twiddle tables. Must be called once before fft().
  void init() {
    const double pi = std::acos(-1.0);

    for (int i = 0; i < M / 2; i++) {
      vrots[i].x = _mm256_set1_pd(std::cos(2 * pi / M * i));
      vrots[i].y = _mm256_set1_pd(std::sin(2 * pi / M * i));
    }

    for (int i = 0; i < N / 2; i++) {
      rots[i] = std::polar(1.0, 2 * pi / N * i);
    }
  }

  /// In-place radix-2 decimation-in-time FFT of `a`.
  ///
  /// Forward (rev == false): a[k] <- sum_j a[j] * exp(+2*pi*i*j*k / n).
  /// Inverse (rev == true): the same forward pass, then entries 1..n-1 are
  /// reversed and everything is scaled by 1/n, which yields the inverse DFT.
  ///
  /// The first log2(n/4) butterfly stages run on the four length-n/4 quarters
  /// of `a` in parallel, one quarter per AVX lane; the last two stages are
  /// done on scalar std::complex values.
  ///
  /// @param a    data to transform; its size must be a power of two <= N.
  /// @param rev  false for the forward transform, true for the inverse.
  /// @note init() must have been called first. A function-local static
  ///       scratch buffer is used, so concurrent calls are not safe.
  void fft(std::vector<std::complex<double>> &a, bool rev) {
    const int n = static_cast<int>(a.size());
    // Length 0 or 1 is its own transform (scaling by 1/1 is a no-op). This
    // also keeps the inverse path's reverse(a.begin() + 1, ...) from running
    // on an empty vector, which would be undefined behavior.
    if (n <= 1) {
      return;
    }
    // The bit reversal and the 4-way split are only valid for powers of two;
    // sizes above N would overflow the static scratch buffer `b`.
    assert((n & (n - 1)) == 0 && "vfft::fft: size must be a power of two");
    assert(n <= N && "vfft::fft: size exceeds N");

    if (n == 2) {
      const std::complex<double> s = a[0];
      const std::complex<double> t = a[1];
      a[0] = s + t;
      a[1] = s - t;
      if (rev) {
        a[0] *= 0.5;
        a[1] *= 0.5;
      }
      return;
    }
    const int m = n / 4;  // length of each quarter, i.e. of each SIMD lane

    // Bit-reversal permutation: rj is j with its log2(n) bits reversed.
    for (int j = 1, rj = 0; j < n - 1; j++) {
      for (int k = n >> 1; k > (rj ^= k); k >>= 1) {
      }
      if (j < rj) {
        std::swap(a[j], a[rj]);
      }
    }

    // After bit reversal, every stage with half-width s < m only combines
    // elements inside one aligned quarter, so the four quarters are
    // independent transforms: pack quarter q into lane q.
    // Static to keep this 2^19 * 64-byte (32 MiB) buffer off the stack.
    static Complex256 b[N / 4];
    for (int i = 0; i < m; i++) {
      b[i].x = _mm256_setr_pd(a[i].real(), a[i + m].real(), a[i + m * 2].real(), a[i + m * 3].real());
      b[i].y = _mm256_setr_pd(a[i].imag(), a[i + m].imag(), a[i + m * 2].imag(), a[i + m * 3].imag());
    }

    // Stages of half-width s = 1, 2, ..., m/2 on all four lanes at once;
    // vrots[M / 2 / s * k] = exp(pi*i*k / s).
    for (int s = 1; s < m; s *= 2) {
      for (int j = 0; j < m; j += s * 2) {
        for (int k = 0; k < s; k++) {
          const Complex256 u = b[j + k];
          const Complex256 t = b[j + k + s] * vrots[M / 2 / s * k];
          b[j + k] = u + t;
          b[j + k + s] = u - t;
        }
      }
    }

    // Unpack lane q back into quarter q.
    for (int i = 0; i < m; i++) {
      alignas(32) double tx[4];
      alignas(32) double ty[4];
      _mm256_store_pd(tx, b[i].x);
      _mm256_store_pd(ty, b[i].y);
      for (int q = 0; q < 4; q++) {
        a[i + m * q] = std::complex<double>(tx[q], ty[q]);
      }
    }

    // The remaining two stages (half-width m and 2m) mix the quarters and are
    // done on scalars; rots[N / 2 / s * k] = exp(pi*i*k / s).
    for (int s = m; s < n; s *= 2) {
      for (int j = 0; j < n; j += s * 2) {
        for (int k = 0; k < s; k++) {
          const std::complex<double> u = a[j + k];
          const std::complex<double> t = a[j + k + s] * rots[N / 2 / s * k];
          a[j + k] = u + t;
          a[j + k + s] = u - t;
        }
      }
    }

    if (rev) {
      // The forward transform at index n - k equals n times the inverse at k.
      std::reverse(a.begin() + 1, a.end());
      for (std::complex<double> &v : a) {
        v *= 1.0 / n;
      }
    }
  }
}

namespace fft {
  /// Largest supported transform length (a power of two).
  const int N = 1 << 21;
  /// rots[j] = exp(2*pi*i*j / N), for j in [0, N/2).
  std::complex<double> rots[N / 2];

  /// Fills the twiddle table. Must be called once before fft().
  void init() {
    const double pi = std::acos(-1.0);

    for (int i = 0; i < N / 2; i++) {
      rots[i].real(std::cos(2 * pi / N * i));
      rots[i].imag(std::sin(2 * pi / N * i));
    }
  }

  /// In-place radix-2 decimation-in-time FFT of `a` (plain scalar version).
  ///
  /// Forward (rev == false): a[k] <- sum_j a[j] * exp(+2*pi*i*j*k / n).
  /// Inverse (rev == true): the same forward pass, then entries 1..n-1 are
  /// reversed and everything is scaled by 1/n, which yields the inverse DFT.
  ///
  /// @param a    data to transform; its size must be a power of two <= N.
  /// @param rev  false for the forward transform, true for the inverse.
  /// @note init() must have been called first.
  void fft(std::vector<std::complex<double>> &a, bool rev) {
    const int n = static_cast<int>(a.size());
    // Length 0 or 1 is its own transform (scaling by 1/1 is a no-op). This
    // also keeps the inverse path's reverse(a.begin() + 1, ...) from running
    // on an empty vector, which would be undefined behavior.
    if (n <= 1) {
      return;
    }
    // For sizes above N, N / 2 / s would be 0 and every twiddle would silently
    // become rots[0] == 1; non-powers of two break the bit reversal.
    assert((n & (n - 1)) == 0 && "fft::fft: size must be a power of two");
    assert(n <= N && "fft::fft: size exceeds N");

    // Bit-reversal permutation: rj is j with its log2(n) bits reversed.
    for (int j = 1, rj = 0; j < n - 1; j++) {
      for (int k = n >> 1; k > (rj ^= k); k >>= 1) {
      }
      if (j < rj) {
        std::swap(a[j], a[rj]);
      }
    }

    // Butterfly stages of half-width s = 1, 2, ..., n/2;
    // rots[N / 2 / s * k] = exp(pi*i*k / s).
    for (int s = 1; s < n; s *= 2) {
      for (int j = 0; j < n; j += s * 2) {
        for (int k = 0; k < s; k++) {
          const std::complex<double> u = a[j + k];
          const std::complex<double> t = a[j + k + s] * rots[N / 2 / s * k];
          a[j + k] = u + t;
          a[j + k + s] = u - t;
        }
      }
    }

    if (rev) {
      // The forward transform at index n - k equals n times the inverse at k.
      std::reverse(a.begin() + 1, a.end());
      for (std::complex<double> &v : a) {
        v *= 1.0 / n;
      }
    }
  }
}

int main() {
  int n = 500000;
 
  const int N = 1 << 21;
  vector<complex<double>> a(N), b(N);
  for (int i = 1; i <= n; i++) {
    int x = uniform_int_distribution<int>(1, 10)(mt);
    int y = uniform_int_distribution<int>(1, 10)(mt);
    a[i].real(x);
    b[i].real(y);
  }
 
#ifdef VEC
  vfft::init();
  clock_t start = clock();
  vfft::fft(a, false);
  vfft::fft(b, false);
#else
  fft::init();
  clock_t start = clock();
  fft::fft(a, false);
  fft::fft(b, false);
#endif
 
  for (int i = 0; i < N; i++) {
    a[i] *= b[i];
  }
 
#ifdef VEC
  vfft::fft(a, true);
#else
  fft::fft(a, true);
#endif
  clock_t end = clock();

  printf("%.4f\n", (double)(end - start) / CLOCKS_PER_SEC);
}

計測対象は 3 回の FFT (順変換 2 回と逆変換 1 回) と、その間の要素ごとの積の全体で、入出力時間は含めない。あと、cos, sin のテーブルの前計算はある意味でボトルネックだと思うけど、今回は計測対象にしない。

$ g++ a.cpp -O3 -mavx
$ ./a.out
1.0000
$ g++ a.cpp -O3 -mavx -DVEC
$ ./a.out
0.4062
$

手元だけでなく、いくつかのオンライン環境でも計測した。計測時間に入出力は含めない。普通の (SIMD を使わない) 版について、complex の掛け算を自作関数に置き換えたものの結果も載せておく。CodeChef は計測ごとの変動幅が大きかったので当てにならないかも。

          SIMD    NORMAL  NORMAL(my mul)
----------------------------------------
My        0.4062  1.0000  0.7656
CodeChef  0.5966  1.3077  1.0879
Ideone    0.5077  0.8630  0.7129
AtCoder   0.4775  0.8146  0.5653

自作の掛け算を使った普通の版と比べると、SIMD 版は微妙に速くなっている程度…。

追記

rots を二分ヒープやセグメント木のような要領で配置する (段 i で使う回転因子を rots[i], ..., rots[2i-1] に連続して並べる) と、キャッシュに乗って速くなる。

#pragma GCC optimize ("O3")
#pragma GCC target ("avx")
#include <iostream>
#include <algorithm>
#include <vector>
#include <complex>
#include <cstdio>
#include <cassert>
#include <immintrin.h>
#include <cmath>
#include <ctime>
#include <cstring>
#include <cassert>
#include <random>
 
using namespace std;
 
// Modulus constant carried over from a template; not referenced by this program.
constexpr int mod = 1000000007;

// Fixed seed so that every run benchmarks the same pseudo-random input.
std::mt19937 mt(123);

// Textbook complex product (ar + i*ai)(br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br),
// used in the butterflies in place of std::complex's operator*.
inline std::complex<double> mul(std::complex<double> a, std::complex<double> b) {
  const double re = a.real() * b.real() - a.imag() * b.imag();
  const double im = a.real() * b.imag() + a.imag() * b.real();
  return std::complex<double>(re, im);
}

namespace vfft {
  /// Four complex doubles in split form: lane j of `x` is the real part and
  /// lane j of `y` the imaginary part of the j-th value.
  struct Complex256 {
    __m256d x;  // real parts, one per lane
    __m256d y;  // imaginary parts, one per lane
  };

  /// Lane-wise complex addition.
  inline Complex256 operator+(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_add_pd(a.x, b.x);
    res.y = _mm256_add_pd(a.y, b.y);
    return res;
  }

  /// Lane-wise complex subtraction.
  inline Complex256 operator-(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_sub_pd(a.x, b.x);
    res.y = _mm256_sub_pd(a.y, b.y);
    return res;
  }

  /// Lane-wise complex multiplication:
  /// (ax + i*ay)(bx + i*by) = (ax*bx - ay*by) + i*(ax*by + ay*bx).
  inline Complex256 operator*(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_sub_pd(_mm256_mul_pd(a.x, b.x), _mm256_mul_pd(a.y, b.y));
    res.y = _mm256_add_pd(_mm256_mul_pd(a.x, b.y), _mm256_mul_pd(a.y, b.x));
    return res;
  }

  /// Largest supported transform length (a power of two).
  const int N = 1 << 21;
  /// Length of each lane-parallel sub-transform when n == N.
  const int M = N / 4;
  /// Broadcast twiddles in heap order (like a binary heap / segment tree):
  /// vrots[M/2 + j] = exp(2*pi*i*j / M) for j in [0, M/2), and
  /// vrots[p] = vrots[2p] for 1 <= p < M/2. For a stage of half-width s
  /// (a power of two, s <= M/2) and 0 <= k < s this gives
  /// vrots[s + k] = exp(pi*i*k / s), so each stage reads one contiguous run.
  /// vrots[0] is never used.
  Complex256 vrots[M];
  /// Scalar twiddles in the same heap order: rots[N/2 + j] = exp(2*pi*i*j / N)
  /// and rots[p] = rots[2p], hence rots[s + k] = exp(pi*i*k / s).
  /// rots[0] is never used.
  std::complex<double> rots[N];

  /// Fills the twiddle tables. Must be called once before fft().
  void init() {
    const double pi = std::acos(-1.0);

    // Bottom level: the complete table for the widest stage.
    for (int i = 0; i < M / 2; i++) {
      vrots[i + M / 2].x = _mm256_set1_pd(std::cos(2 * pi / M * i));
      vrots[i + M / 2].y = _mm256_set1_pd(std::sin(2 * pi / M * i));
    }
    // Each parent copies its left child, i.e. every other entry of the level below.
    for (int i = M / 2 - 1; i >= 1; i--) {
      vrots[i] = vrots[i * 2];
    }

    for (int i = 0; i < N / 2; i++) {
      rots[i + N / 2] = std::polar(1.0, 2 * pi / N * i);
    }
    for (int i = N / 2 - 1; i >= 1; i--) {
      rots[i] = rots[i * 2];
    }
  }

  /// In-place radix-2 decimation-in-time FFT of `a`.
  ///
  /// Forward (rev == false): a[k] <- sum_j a[j] * exp(+2*pi*i*j*k / n).
  /// Inverse (rev == true): the same forward pass, then entries 1..n-1 are
  /// reversed and everything is scaled by 1/n, which yields the inverse DFT.
  ///
  /// The first log2(n/4) butterfly stages run on the four length-n/4 quarters
  /// of `a` in parallel, one quarter per AVX lane; the last two stages are
  /// done on scalar values using the file-level helper mul().
  ///
  /// @param a    data to transform; its size must be a power of two <= N.
  /// @param rev  false for the forward transform, true for the inverse.
  /// @note init() must have been called first. A function-local static
  ///       scratch buffer is used, so concurrent calls are not safe.
  void fft(std::vector<std::complex<double>> &a, bool rev) {
    const int n = static_cast<int>(a.size());
    // Length 0 or 1 is its own transform (scaling by 1/1 is a no-op). This
    // also keeps the inverse path's reverse(a.begin() + 1, ...) from running
    // on an empty vector, which would be undefined behavior.
    if (n <= 1) {
      return;
    }
    // The bit reversal and the 4-way split are only valid for powers of two;
    // sizes above N would index past the twiddle tables and the scratch buffer.
    assert((n & (n - 1)) == 0 && "vfft::fft: size must be a power of two");
    assert(n <= N && "vfft::fft: size exceeds N");

    if (n == 2) {
      const std::complex<double> s = a[0];
      const std::complex<double> t = a[1];
      a[0] = s + t;
      a[1] = s - t;
      if (rev) {
        a[0] *= 0.5;
        a[1] *= 0.5;
      }
      return;
    }
    const int m = n / 4;  // length of each quarter, i.e. of each SIMD lane

    // Bit-reversal permutation: rj is j with its log2(n) bits reversed.
    for (int j = 1, rj = 0; j < n - 1; j++) {
      for (int k = n >> 1; k > (rj ^= k); k >>= 1) {
      }
      if (j < rj) {
        std::swap(a[j], a[rj]);
      }
    }

    // After bit reversal, every stage with half-width s < m only combines
    // elements inside one aligned quarter, so the four quarters are
    // independent transforms: pack quarter q into lane q.
    // Static to keep this 2^19 * 64-byte (32 MiB) buffer off the stack.
    static Complex256 b[N / 4];
    for (int i = 0; i < m; i++) {
      b[i].x = _mm256_setr_pd(a[i].real(), a[i + m].real(), a[i + m * 2].real(), a[i + m * 3].real());
      b[i].y = _mm256_setr_pd(a[i].imag(), a[i + m].imag(), a[i + m * 2].imag(), a[i + m * 3].imag());
    }

    // Stages of half-width s = 1, 2, ..., m/2 on all four lanes at once;
    // vrots[s + k] = exp(pi*i*k / s).
    for (int s = 1; s < m; s *= 2) {
      for (int j = 0; j < m; j += s * 2) {
        for (int k = 0; k < s; k++) {
          const Complex256 u = b[j + k];
          const Complex256 t = b[j + k + s] * vrots[s + k];
          b[j + k] = u + t;
          b[j + k + s] = u - t;
        }
      }
    }

    // Unpack lane q back into quarter q.
    for (int i = 0; i < m; i++) {
      alignas(32) double tx[4];
      alignas(32) double ty[4];
      _mm256_store_pd(tx, b[i].x);
      _mm256_store_pd(ty, b[i].y);
      for (int q = 0; q < 4; q++) {
        a[i + m * q] = std::complex<double>(tx[q], ty[q]);
      }
    }

    // The remaining two stages (half-width m and 2m) mix the quarters and are
    // done on scalars; rots[s + k] = exp(pi*i*k / s).
    for (int s = m; s < n; s *= 2) {
      for (int j = 0; j < n; j += s * 2) {
        for (int k = 0; k < s; k++) {
          const std::complex<double> u = a[j + k];
          const std::complex<double> t = mul(a[j + k + s], rots[s + k]);
          a[j + k] = u + t;
          a[j + k + s] = u - t;
        }
      }
    }

    if (rev) {
      // The forward transform at index n - k equals n times the inverse at k.
      std::reverse(a.begin() + 1, a.end());
      for (std::complex<double> &v : a) {
        v *= 1.0 / n;
      }
    }
  }
}

namespace fft {
  /// Largest supported transform length (a power of two).
  const int N = 1 << 21;
  /// Twiddles in heap order (like a binary heap / segment tree):
  /// rots[N/2 + j] = exp(2*pi*i*j / N) for j in [0, N/2), and
  /// rots[p] = rots[2p] for 1 <= p < N/2. For a stage of half-width s
  /// (a power of two, s <= N/2) and 0 <= k < s this gives
  /// rots[s + k] = exp(pi*i*k / s), so each stage reads one contiguous run.
  /// rots[0] is never used.
  std::complex<double> rots[N];

  /// Fills the twiddle table. Must be called once before fft().
  void init() {
    const double pi = std::acos(-1.0);

    // Bottom level: the complete table for the widest stage.
    for (int i = 0; i < N / 2; i++) {
      rots[i + N / 2].real(std::cos(2 * pi / N * i));
      rots[i + N / 2].imag(std::sin(2 * pi / N * i));
    }
    // Each parent copies its left child, i.e. every other entry of the level below.
    for (int i = N / 2 - 1; i >= 1; i--) {
      rots[i] = rots[i * 2];
    }
  }

  /// In-place radix-2 decimation-in-time FFT of `a` (plain scalar version,
  /// using the file-level helper mul() for the twiddle products).
  ///
  /// Forward (rev == false): a[k] <- sum_j a[j] * exp(+2*pi*i*j*k / n).
  /// Inverse (rev == true): the same forward pass, then entries 1..n-1 are
  /// reversed and everything is scaled by 1/n, which yields the inverse DFT.
  ///
  /// @param a    data to transform; its size must be a power of two <= N.
  /// @param rev  false for the forward transform, true for the inverse.
  /// @note init() must have been called first.
  void fft(std::vector<std::complex<double>> &a, bool rev) {
    const int n = static_cast<int>(a.size());
    // Length 0 or 1 is its own transform (scaling by 1/1 is a no-op). This
    // also keeps the inverse path's reverse(a.begin() + 1, ...) from running
    // on an empty vector, which would be undefined behavior.
    if (n <= 1) {
      return;
    }
    // Sizes above N would read rots[s + k] past the end of the table;
    // non-powers of two break the bit reversal and the butterfly loops.
    assert((n & (n - 1)) == 0 && "fft::fft: size must be a power of two");
    assert(n <= N && "fft::fft: size exceeds N");

    // Bit-reversal permutation: rj is j with its log2(n) bits reversed.
    for (int j = 1, rj = 0; j < n - 1; j++) {
      for (int k = n >> 1; k > (rj ^= k); k >>= 1) {
      }
      if (j < rj) {
        std::swap(a[j], a[rj]);
      }
    }

    // Butterfly stages of half-width s = 1, 2, ..., n/2;
    // rots[s + k] = exp(pi*i*k / s).
    for (int s = 1; s < n; s *= 2) {
      for (int j = 0; j < n; j += s * 2) {
        for (int k = 0; k < s; k++) {
          const std::complex<double> u = a[j + k];
          const std::complex<double> t = mul(a[j + k + s], rots[s + k]);
          a[j + k] = u + t;
          a[j + k + s] = u - t;
        }
      }
    }

    if (rev) {
      // The forward transform at index n - k equals n times the inverse at k.
      std::reverse(a.begin() + 1, a.end());
      for (std::complex<double> &v : a) {
        v *= 1.0 / n;
      }
    }
  }
}

int main() {
  int n = 1000000;
  // scanf("%d", &n);
 
  const int N = 1 << 21;
  vector<complex<double>> a(N), b(N);
  for (int i = 1; i <= n; i++) {
    // int x, y;
    // scanf("%d %d", &x, &y);
    int x = uniform_int_distribution<int>(1, 10)(mt);
    int y = uniform_int_distribution<int>(1, 10)(mt);
    a[i].real(x);
    b[i].real(y);
  }
 
#ifdef VEC
  vfft::init();
  clock_t start = clock();
  vfft::fft(a, false);
  vfft::fft(b, false);
#else
  fft::init();
  clock_t start = clock();
  fft::fft(a, false);
  fft::fft(b, false);
#endif
 
  for (int i = 0; i < N; i++) {
    a[i] *= b[i];
  }
 
#ifdef VEC
  vfft::fft(a, true);
#else
  fft::fft(a, true);
#endif
  clock_t end = clock();

  printf("%.4f\n", (double)(end - start) / CLOCKS_PER_SEC);

  // for (int i = 1; i <= 2 * n; i++) {
  //   printf("%d\n", (int)round(a[i].real()));
  // }
}
$ g++ a.cpp -O3 -mavx
$ ./a.out
0.2969
$ g++ a.cpp -O3 -mavx -DVEC
$ ./a.out
0.2969
$

うーん。コンパイラが優秀なのか、SIMD がうまく機能していないのか、それとも他に原因があるのか。

環境

$ gcc --version
gcc (Debian 6.3.0-18+deb9u1) 6.3.0 20170516
Copyright (C) 2016 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

$ cat /proc/cpuinfo
processor       : 0
vendor_id       : GenuineIntel
cpu family      : 6
model           : 94
model name      : Intel(R) Core(TM) i5-6500 CPU @ 3.20GHz
stepping        : 3
microcode       : 0xffffffff
cpu MHz         : 3201.000
cache size      : 256 KB
physical id     : 0
siblings        : 4
core id         : 0
cpu cores       : 4
apicid          : 0
initial apicid  : 0
fpu             : yes
fpu_exception   : yes
cpuid level     : 6
wp              : yes
flags           : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave osxsave avx f16c rdrand
bogomips        : 6402.00
clflush size    : 64
cache_alignment : 64
address sizes   : 36 bits physical, 48 bits virtual
power management:

processor       : 1
vendor_id       : GenuineIntel
cpu family      : 6
model           : 94
model name      : Intel(R) Core(TM) i5-6500 CPU @ 3.20GHz
stepping        : 3
microcode       : 0xffffffff
cpu MHz         : 3201.000
cache size      : 256 KB
physical id     : 0
siblings        : 4
core id         : 1
cpu cores       : 4
apicid          : 0
initial apicid  : 0
fpu             : yes
fpu_exception   : yes
cpuid level     : 6
wp              : yes
flags           : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave osxsave avx f16c rdrand
bogomips        : 6402.00
clflush size    : 64
cache_alignment : 64
address sizes   : 36 bits physical, 48 bits virtual
power management:

processor       : 2
vendor_id       : GenuineIntel
cpu family      : 6
model           : 94
model name      : Intel(R) Core(TM) i5-6500 CPU @ 3.20GHz
stepping        : 3
microcode       : 0xffffffff
cpu MHz         : 3201.000
cache size      : 256 KB
physical id     : 0
siblings        : 4
core id         : 2
cpu cores       : 4
apicid          : 0
initial apicid  : 0
fpu             : yes
fpu_exception   : yes
cpuid level     : 6
wp              : yes
flags           : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave osxsave avx f16c rdrand
bogomips        : 6402.00
clflush size    : 64
cache_alignment : 64
address sizes   : 36 bits physical, 48 bits virtual
power management:

processor       : 3
vendor_id       : GenuineIntel
cpu family      : 6
model           : 94
model name      : Intel(R) Core(TM) i5-6500 CPU @ 3.20GHz
stepping        : 3
microcode       : 0xffffffff
cpu MHz         : 3201.000
cache size      : 256 KB
physical id     : 0
siblings        : 4
core id         : 3
cpu cores       : 4
apicid          : 0
initial apicid  : 0
fpu             : yes
fpu_exception   : yes
cpuid level     : 6
wp              : yes
flags           : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave osxsave avx f16c rdrand
bogomips        : 6402.00
clflush size    : 64
cache_alignment : 64
address sizes   : 36 bits physical, 48 bits virtual
power management:

$