# FFT

（追記：rots の配置を変えたら、SIMD を使った版と使わない版とで速度がまったく同じになった…。切ない）

#pragma GCC optimize ("O3")
#pragma GCC target ("avx")
#include <iostream>
#include <algorithm>
#include <vector>
#include <complex>
#include <cstdio>
#include <cassert>
#include <immintrin.h>
#include <cmath>
#include <ctime>
#include <cstring>
#include <cassert>
#include <random>

using namespace std;

// Modulus 1e9 + 7 (not referenced by the code in this listing).
const int mod = 1000000007;

// Fixed-seed Mersenne Twister so every run of main() sees the same inputs.
mt19937 mt(123);

namespace vfft {
// Four complex<double> values in structure-of-arrays form: lane j of `x` is
// the real part and lane j of `y` the imaginary part of the j-th value.
struct Complex256 {
    __m256d x;
    __m256d y;
};

// Lane-wise complex addition.
inline Complex256 operator+(Complex256 a, Complex256 b) {
    Complex256 res;
    // Both components must be written; returning `res` untouched would feed
    // indeterminate values into every butterfly of the SIMD stages.
    res.x = _mm256_add_pd(a.x, b.x);
    res.y = _mm256_add_pd(a.y, b.y);
    return res;
}

// Lane-wise complex subtraction.
inline Complex256 operator-(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_sub_pd(a.x, b.x);
    res.y = _mm256_sub_pd(a.y, b.y);
    return res;
}

// Lane-wise complex multiplication: (ax + I*ay) * (bx + I*by).
inline Complex256 operator*(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_sub_pd(_mm256_mul_pd(a.x, b.x), _mm256_mul_pd(a.y, b.y));
    res.y = _mm256_add_pd(_mm256_mul_pd(a.x, b.y), _mm256_mul_pd(a.y, b.x));
    return res;
}

// Largest supported transform size.
const int N = 1 << 21;
// Largest length handled by the SIMD stages (each AVX lane holds one quarter
// of the input, so a quarter is at most N / 4 long).
const int M = N / 4;
// vrots[t] = exp(+2*pi*I*t / M), broadcast to all four lanes.
Complex256 vrots[M / 2];
// rots[t] = exp(+2*pi*I*t / N).
complex<double> rots[N / 2];

// Fills the twiddle tables vrots and rots. Must be called before fft().
void init() {
    const double pi = acos(-1);

    for (int i = 0; i < M / 2; i++) {
        vrots[i].x = _mm256_set1_pd(cos(2 * pi / M * i));
        vrots[i].y = _mm256_set1_pd(sin(2 * pi / M * i));
    }

    for (int i = 0; i < N / 2; i++) {
        rots[i] = polar(1.0, 2 * pi / N * i);
    }
}

// In-place radix-2 FFT of `a` using the tables filled by init().
//
// rev == false computes A[k] = sum_j a[j] * exp(+2*pi*I*j*k / n).
// rev == true applies the same transform, reverses a[1..n-1] and divides by n,
// i.e. the inverse, so fft(a, false) followed by fft(a, true) restores `a`.
//
// Requirements: init() has been called, a.size() is a power of two and
// a.size() <= N. Sizes 0 and 1 are no-ops.
//
// After bit-reversal, the first log2(n / 4) stages stay inside each quarter of
// the array, so the four quarters are processed together, one per AVX lane.
// The last two stages cross quarters and use scalar complex arithmetic.
// Not reentrant: every call shares the function-local static buffer `b`.
void fft(vector<complex<double>> &a, bool rev) {
    const int n = a.size();
    // Empty or single-element input is already its own transform. Returning
    // here also keeps the inverse path from forming a.begin() + 1 past end().
    if (n <= 1) {
        return;
    }
    if (n == 2) {
        complex<double> s = a[0];
        complex<double> t = a[1];
        a[0] = s + t;
        a[1] = s - t;
        if (rev) {
            a[0] *= 0.5;
            a[1] *= 0.5;
        }
        return;
    }
    // The stage loops assume a power of two; e.g. n == 3 would give m == 0
    // and the scalar stage loop (i *= 2 starting at 0) would never end.
    assert((n & (n - 1)) == 0);
    assert(n <= N);
    const int m = n / 4;

    // Bit-reversal permutation.
    int i = 0;
    for (int j = 1; j < n - 1; j++) {
        for (int k = n >> 1; k > (i ^= k); k >>= 1);
        if (j < i) {
            swap(a[i], a[j]);
        }
    }

    // b[t] packs a[t], a[t + m], a[t + 2m], a[t + 3m] into lanes 0..3.
    static Complex256 b[N / 4];
    for (int i = 0; i < m; i++) {
        b[i].x = _mm256_setr_pd(a[i].real(), a[i + m].real(), a[i + m * 2].real(), a[i + m * 3].real());
        b[i].y = _mm256_setr_pd(a[i].imag(), a[i + m].imag(), a[i + m * 2].imag(), a[i + m * 3].imag());
    }

    // SIMD stages with half-block size i < m.
    // vrots[M / 2 / i * k] == exp(+2*pi*I*k / (2*i)).
    for (int i = 1; i < m; i *= 2) {
        for (int j = 0; j < m; j += i * 2) {
            for (int k = 0; k < i; k++) {
                auto s = b[j + k + 0];
                auto t = b[j + k + i] * vrots[M / 2 / i * k];
                b[j + k + 0] = s + t;
                b[j + k + i] = s - t;
            }
        }
    }

    // Unpack the lanes back into their quarters of `a`.
    for (int i = 0; i < m; i++) {
        alignas(32) double tx[4];
        alignas(32) double ty[4];
        _mm256_store_pd(tx, b[i].x);
        _mm256_store_pd(ty, b[i].y);
        for (int j = 0; j < 4; j++) {
            a[i + m * j].real(tx[j]);
            a[i + m * j].imag(ty[j]);
        }
    }

    // Remaining two scalar stages (i == m and i == 2m).
    // rots[N / 2 / i * k] == exp(+2*pi*I*k / (2*i)).
    for (int i = m; i < n; i *= 2) {
        for (int j = 0; j < n; j += i * 2) {
            for (int k = 0; k < i; k++) {
                auto s = a[j + k + 0];
                auto t = a[j + k + i] * rots[N / 2 / i * k];
                a[j + k + 0] = s + t;
                a[j + k + i] = s - t;
            }
        }
    }

    if (rev) {
        reverse(a.begin() + 1, a.end());
        for (int i = 0; i < n; i++) {
            a[i] *= 1.0 / n;
        }
    }
}
}  // namespace vfft

namespace fft {
// Largest supported transform size.
const int N = 1 << 21;
// rots[t] = exp(+2*pi*I*t / N).
complex<double> rots[N / 2];

// Fills the twiddle table rots. Must be called before fft().
void init() {
    const double pi = acos(-1);

    for (int i = 0; i < N / 2; i++) {
        rots[i].real(cos(2 * pi / N * i));
        rots[i].imag(sin(2 * pi / N * i));
    }
}

// In-place scalar radix-2 FFT of `a` using the table filled by init().
//
// rev == false computes A[k] = sum_j a[j] * exp(+2*pi*I*j*k / n).
// rev == true applies the same transform, reverses a[1..n-1] and divides by n,
// i.e. the inverse, so fft(a, false) followed by fft(a, true) restores `a`.
//
// Requirements: init() has been called, a.size() is a power of two and
// a.size() <= N. Sizes 0 and 1 are no-ops.
void fft(vector<complex<double>> &a, bool rev) {
    const int n = a.size();
    // Empty or single-element input is already its own transform. Returning
    // here also keeps the inverse path from forming a.begin() + 1 past end().
    if (n <= 1) {
        return;
    }
    // Non-powers of two would index past the end of `a` in the stage loop.
    // For n > N, N / 2 / i becomes 0 and every twiddle would silently be 1.
    assert((n & (n - 1)) == 0);
    assert(n <= N);

    // Bit-reversal permutation.
    int i = 0;
    for (int j = 1; j < n - 1; j++) {
        for (int k = n >> 1; k > (i ^= k); k >>= 1);
        if (j < i) {
            swap(a[i], a[j]);
        }
    }

    // Butterfly stages with half-block size i.
    // rots[N / 2 / i * k] == exp(+2*pi*I*k / (2*i)).
    for (int i = 1; i < n; i *= 2) {
        for (int j = 0; j < n; j += i * 2) {
            for (int k = 0; k < i; k++) {
                auto s = a[j + k + 0];
                auto t = a[j + k + i] * rots[N / 2 / i * k];
                a[j + k + 0] = s + t;
                a[j + k + i] = s - t;
            }
        }
    }

    if (rev) {
        reverse(a.begin() + 1, a.end());
        for (int i = 0; i < n; i++) {
            a[i] *= 1.0 / n;
        }
    }
}
}  // namespace fft

int main() {
int n = 500000;

const int N = 1 << 21;
vector<complex<double>> a(N), b(N);
for (int i = 1; i <= n; i++) {
int x = uniform_int_distribution<int>(1, 10)(mt);
int y = uniform_int_distribution<int>(1, 10)(mt);
a[i].real(x);
b[i].real(y);
}

#ifdef VEC
vfft::init();
clock_t start = clock();
vfft::fft(a, false);
vfft::fft(b, false);
#else
fft::init();
clock_t start = clock();
fft::fft(a, false);
fft::fft(b, false);
#endif

for (int i = 0; i < N; i++) {
a[i] *= b[i];
}

#ifdef VEC
vfft::fft(a, true);
#else
fft::fft(a, true);
#endif
clock_t end = clock();

printf("%.4f\n", (double)(end - start) / CLOCKS_PER_SEC);
}


```sh
$ g++ a.cpp -O3 -mavx
$ ./a.out
1.0000
$ g++ a.cpp -O3 -mavx -DVEC
$ ./a.out
0.4062
```
手元だけでなく、いくつかの環境でも計測した。計測時間に入出力は含めない。通常版については、complex の掛け算を自作したもの（NORMAL(my mul)）の結果も載せておく。CodeChef は計測ごとの変動幅が大きかったので、当てにならないかもしれない。

|          | NORMAL | SIMD   | NORMAL(my mul) |
|----------|--------|--------|----------------|
| My       | 0.4062 | 1.0000 | 0.7656         |
| CodeChef | 0.5966 | 1.3077 | 1.0879         |
| Ideone   | 0.5077 | 0.8630 | 0.7129         |
| AtCoder  | 0.4775 | 0.8146 | 0.5653         |

（※上の手元の計測では -DVEC 付き（SIMD）が 0.4062、なし（通常版）が 1.0000 なので、NORMAL 列と SIMD 列の見出しは入れ替わっている可能性がある。）

微妙に速くなっている程度…。

### 追記

rots を二分ヒープやセグメント木のような要領で配置すると、キャッシュに乗りやすくなって速くなる。

```cpp
#pragma GCC optimize ("O3")
#pragma GCC target ("avx")
#include <iostream>
#include <algorithm>
#include <vector>
#include <complex>
#include <cstdio>
#include <cassert>
#include <immintrin.h>
#include <cmath>
#include <ctime>
#include <cstring>
#include <cassert>
#include <random>

using namespace std;

const int mod = 1e9 + 7;

mt19937 mt(123);

inline complex<double> mul(complex<double> a, complex<double> b) {
    return complex<double>(
        a.real() * b.real() - a.imag() * b.imag(),
        a.real() * b.imag() + a.imag() * b.real()
    );
}

namespace vfft {
struct Complex256 {
    __m256d x;
    __m256d y;
};

inline Complex256 operator+(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_add_pd(a.x, b.x);
    res.y = _mm256_add_pd(a.y, b.y);
    return res;
}

inline Complex256 operator-(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_sub_pd(a.x, b.x);
    res.y = _mm256_sub_pd(a.y, b.y);
    return res;
}

inline Complex256 operator*(Complex256 a, Complex256 b) {
    Complex256 res;
    res.x = _mm256_sub_pd(_mm256_mul_pd(a.x, b.x), _mm256_mul_pd(a.y, b.y));
    res.y = _mm256_add_pd(_mm256_mul_pd(a.x, b.y), _mm256_mul_pd(a.y, b.x));
    return res;
}

const int N = 1 << 21;
const int M = N / 4;
Complex256 vrots[M];
complex<double> rots[N];

void init() {
    const double pi = acos(-1);

    for (int i = 0; i < M / 2; i++) {
        vrots[i + M / 2].x = _mm256_set1_pd(cos(2 * pi / M * i));
        vrots[i + M / 2].y = _mm256_set1_pd(sin(2 * pi / M * i));
    }
    for (int i = M / 2 - 1; i >= 1; i--) {
        vrots[i] = vrots[i * 2];
    }

    for (int i = 0; i < N / 2; i++) {
        rots[i + N / 2] = polar(1.0, 2 * pi / N * i);
    }
    for (int i = N / 2 - 1; i >= 1; i--) {
        rots[i] = rots[i * 2];
    }
}

void fft(vector<complex<double>> &a, bool rev) {
    const int n = a.size();
    if (n == 1) {
        return;
    }
    if (n == 2) {
        complex<double> s = a[0];
        complex<double> t = a[1];
        a[0] = s + t;
        a[1] = s - t;
        if (rev) {
            a[0] *= 0.5;
            a[1] *= 0.5;
        }
        return;
    }
    const int m = n / 4;
    assert(n <= N);

    int i = 0;
    for (int j = 1; j < n - 1; j++) {
        for (int k = n >> 1; k > (i ^= k); k >>= 1);
        if (j < i) {
            swap(a[i], a[j]);
        }
    }

    static Complex256 b[N / 4];
    for (int i = 0; i < m; i++) {
        b[i].x = _mm256_setr_pd(a[i].real(), a[i + m].real(), a[i + m * 2].real(), a[i + m * 3].real());
        b[i].y = _mm256_setr_pd(a[i].imag(), a[i + m].imag(), a[i + m * 2].imag(), a[i + m * 3].imag());
    }

    for (int i = 1; i < m; i *= 2) {
        for (int j = 0; j < m; j += i * 2) {
            for (int k = 0; k < i; k++) {
                auto s = b[j + k + 0];
                auto t = b[j + k + i] * vrots[i + k];
                b[j + k + 0] = s + t;
                b[j + k + i] = s - t;
            }
        }
    }

    for (int i = 0; i < m; i++) {
        alignas(32) double tx[4];
        alignas(32) double ty[4];
        _mm256_store_pd(tx, b[i].x);
        _mm256_store_pd(ty, b[i].y);
        for (int j = 0; j < 4; j++) {
            a[i + m * j].real(tx[j]);
            a[i + m * j].imag(ty[j]);
        }
    }

    for (int i = m; i < n; i *= 2) {
        for (int j = 0; j < n; j += i * 2) {
            for (int k = 0; k < i; k++) {
                auto s = a[j + k + 0];
                auto t = mul(a[j + k + i], rots[i + k]);
                a[j + k + 0] = s + t;
                a[j + k + i] = s - t;
            }
        }
    }

    if (rev) {
        reverse(a.begin() + 1, a.end());
        for (int i = 0; i < n; i++) {
            a[i] *= 1.0 / n;
        }
    }
}
}

namespace fft {
const int N = 1 << 21;
complex<double> rots[N];

void init() {
    const double pi = acos(-1);

    for (int i = 0; i < N / 2; i++) {
        rots[i + N / 2].real(cos(2 * pi / N * i));
        rots[i + N / 2].imag(sin(2 * pi / N * i));
    }
    for (int i = N / 2 - 1; i >= 1; i--) {
        rots[i] = rots[i * 2];
    }
}

void fft(vector<complex<double>> &a, bool rev) {
    const int n = a.size();

    int i = 0;
    for (int j = 1; j < n - 1; j++) {
        for (int k = n >> 1; k > (i ^= k); k >>= 1);
        if (j < i) {
            swap(a[i], a[j]);
        }
    }

    for (int i = 1; i < n; i *= 2) {
        for (int j = 0; j < n; j += i * 2) {
            for (int k = 0; k < i; k++) {
                auto s = a[j + k + 0];
                auto t = mul(a[j + k + i], rots[i + k]);
                a[j + k + 0] = s + t;
                a[j + k + i] = s - t;
            }
        }
    }

    if (rev) {
        reverse(a.begin() + 1, a.end());
        for (int i = 0; i < n; i++) {
            a[i] *= 1.0 / n;
        }
    }
}
}

int main() {
    int n = 1000000;
    // scanf("%d", &n);

    const int N = 1 << 21;
    vector<complex<double>> a(N), b(N);
    for (int i = 1; i <= n; i++) {
        // int x, y;
        // scanf("%d %d", &x, &y);
        int x = uniform_int_distribution<int>(1, 10)(mt);
        int y = uniform_int_distribution<int>(1, 10)(mt);
        a[i].real(x);
        b[i].real(y);
    }

#ifdef VEC
    vfft::init();
    clock_t start = clock();
    vfft::fft(a, false);
    vfft::fft(b, false);
#else
    fft::init();
    clock_t start = clock();
    fft::fft(a, false);
    fft::fft(b, false);
#endif

    for (int i = 0; i < N; i++) {
        a[i] *= b[i];
    }

#ifdef VEC
    vfft::fft(a, true);
#else
    fft::fft(a, true);
#endif
    clock_t end = clock();

    printf("%.4f\n", (double)(end - start) / CLOCKS_PER_SEC);

    // for (int i = 1; i <= 2 * n; i++) {
    //     printf("%d\n", (int)round(a[i].real()));
    // }
}
```

```sh
$ g++ a.cpp -O3 -mavx
$ ./a.out
0.2969
$ g++ a.cpp -O3 -mavx -DVEC
$ ./a.out
0.2969
```

うーん。コンパイラが優秀なのか、SIMD がうまく機能していないのか、それとも別の理由なのか。

### 環境

```sh
$ gcc --version
gcc (Debian 6.3.0-18+deb9u1) 6.3.0 20170516
Copyright (C) 2016 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
$ cat /proc/cpuinfo
```
processor       : 0
vendor_id       : GenuineIntel
cpu family      : 6
model           : 94
model name      : Intel(R) Core(TM) i5-6500 CPU @ 3.20GHz
stepping        : 3
microcode       : 0xffffffff
cpu MHz         : 3201.000
cache size      : 256 KB
physical id     : 0
siblings        : 4
core id         : 0
cpu cores       : 4
apicid          : 0
initial apicid  : 0
fpu             : yes
fpu_exception   : yes
cpuid level     : 6
wp              : yes
flags           : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave osxsave avx f16c rdrand
bogomips        : 6402.00
clflush size    : 64
cache_alignment : 64
address sizes   : 36 bits physical, 48 bits virtual
power management:

processor       : 1
vendor_id       : GenuineIntel
cpu family      : 6
model           : 94
model name      : Intel(R) Core(TM) i5-6500 CPU @ 3.20GHz
stepping        : 3
microcode       : 0xffffffff
cpu MHz         : 3201.000
cache size      : 256 KB
physical id     : 0
siblings        : 4
core id         : 1
cpu cores       : 4
apicid          : 0
initial apicid  : 0
fpu             : yes
fpu_exception   : yes
cpuid level     : 6
wp              : yes
flags           : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave osxsave avx f16c rdrand
bogomips        : 6402.00
clflush size    : 64
cache_alignment : 64
address sizes   : 36 bits physical, 48 bits virtual
power management:

processor       : 2
vendor_id       : GenuineIntel
cpu family      : 6
model           : 94
model name      : Intel(R) Core(TM) i5-6500 CPU @ 3.20GHz
stepping        : 3
microcode       : 0xffffffff
cpu MHz         : 3201.000
cache size      : 256 KB
physical id     : 0
siblings        : 4
core id         : 2
cpu cores       : 4
apicid          : 0
initial apicid  : 0
fpu             : yes
fpu_exception   : yes
cpuid level     : 6
wp              : yes
flags           : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave osxsave avx f16c rdrand
bogomips        : 6402.00
clflush size    : 64
cache_alignment : 64
address sizes   : 36 bits physical, 48 bits virtual
power management:

processor       : 3
vendor_id       : GenuineIntel
cpu family      : 6
model           : 94
model name      : Intel(R) Core(TM) i5-6500 CPU @ 3.20GHz
stepping        : 3
microcode       : 0xffffffff
cpu MHz         : 3201.000
cache size      : 256 KB
physical id     : 0
siblings        : 4
core id         : 3
cpu cores       : 4
apicid          : 0
initial apicid  : 0
fpu             : yes
fpu_exception   : yes
cpuid level     : 6
wp              : yes
flags           : fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave osxsave avx f16c rdrand
bogomips        : 6402.00
clflush size    : 64
cache_alignment : 64
address sizes   : 36 bits physical, 48 bits virtual
power management:

\$