Peter's codebook A blog full of codes

Template -- FFT

FFT 模板二代,簡單易用、精度高

驗證模板正確性:UVa 12298

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#pragma GCC target ("avx")
#pragma GCC optimize ("O3")
#include <algorithm>
#include <iostream>
#include <cmath>
#include <cstdio>
#include <cstring>
using namespace std;

namespace FFT_TOOL {
    // Fall back to a computed long double pi when the M_PIl macro is not defined.
#ifndef M_PIl
    const long double M_PIl = std::acos(-1.0L);
#endif

    /**
     * Minimal complex number with public real (r) and imaginary (i) parts.
     *
     * The constructor is intentionally non-explicit so that a plain scalar
     * (e.g. `x = 0`) converts to a complex value with zero imaginary part.
     * All arithmetic operators are const so they work on const operands and
     * temporaries alike.
     */
    template <class T> struct Complex {
        T r, i;
        Complex(T _r = 0, T _i = 0) : r(_r), i(_i) {}
        Complex<T> operator+(const Complex<T> &b) const { return Complex<T>(r + b.r, i + b.i); }
        Complex<T> operator-(const Complex<T> &b) const { return Complex<T>(r - b.r, i - b.i); }
        // Requests GCC's fast-math code generation for this function only.
        Complex<T> operator*(const Complex<T> &b) const __attribute__((optimize("fast-math")));
    };

    /** Complex product (a + bi)(c + di) = (ac - bd) + (ad + bc)i. */
    template <class T>
        Complex<T> Complex<T>::operator*(const Complex<T> &b) const {
            return Complex<T>(r * b.r - i * b.i, r * b.i + i * b.r);
        }

    /**
     * Static-only holder for an in-place iterative radix-2 FFT.
     *
     * Use FFT::run(); the length must be a power of two (the butterfly loops
     * and the bit-reversal pass both rely on it).
     */
    class FFT {
        private:
            /** Reorders y[0..len) into bit-reversed index order, in place. */
            template <class T>
                static void change(Complex<T> y[], int len) {
                    int i, j, k;
                    // j tracks the bit-reversed counterpart of i and is advanced
                    // by a "reversed increment" (carry runs from the top bit down).
                    for (i = 1, j = (len >> 1); i < len - 1; ++i) {
                        if (i < j) std::swap(y[i], y[j]);
                        k = (len >> 1);
                        while (j >= k) {
                            j -= k;
                            k >>= 1;
                        }
                        if (j < k) j += k;
                    }
                }
            // Requests GCC's fast-math code generation for this function only.
            template <class T> static void fft(Complex<T> y[], int len, int inv) __attribute__((optimize("fast-math")));
        public:
            /**
             * Transforms y in place.
             *
             * @param y    array of l complex samples; overwritten with the result
             * @param l    number of samples, must be a power of two
             * @param inv  false for the forward transform, true for the inverse
             *             (the inverse includes the 1/l scaling)
             */
            template <class T> static void run(Complex<T> y[], int l, bool inv = false) {
                fft(y, l, inv ? -1 : 1);
            }
    };

    /**
     * Transform core. inv == 1 selects the forward transform, inv == -1 the
     * inverse; the inverse divides BOTH components of every sample by len so
     * that run(forward) followed by run(inverse) restores complex input.
     */
    template <class T>
        void FFT::fft(Complex<T> y[], int len, int inv) {
            change(y, len);
            for (int h = 2; h <= len; h <<= 1) {
                const int half = h / 2;
                // Principal h-th root of unity for this stage, evaluated in long double.
                const long double angle = -inv * 2 * M_PIl / h;
                const Complex<T> wn(std::cos(angle), std::sin(angle));
                for (int j = 0; j < len; j += h) {
                    Complex<T> w(1, 0);
                    for (int k = j; k < j + half; k++) {
                        const Complex<T> u = y[k];
                        const Complex<T> t = w * y[k + half];
                        y[k] = u + t;
                        y[k + half] = u - t;
                        w = w * wn;
                    }
                }
            }
            if (inv == -1) {
                for (int i = 0; i < len; i++) {
                    y[i].r /= len;
                    y[i].i /= len;
                }
            }
        }

}; // namespace FFT_TOOL

#include <cstring>

using namespace std;
using namespace FFT_TOOL;

// Capacity of each per-suit coefficient array (262144 == 2^18).
const size_t MAXN = 262144;
// Sieve table filled by mk_prime(): a cell is 1 when its index is a proper
// multiple of some smaller number >= 2 (i.e. composite), 0 otherwise.
int prime_list[50010];
// One coefficient array per suit; main() uses rows 0..3 for 'S', 'H', 'C', 'D'.
Complex<long double> card[4][MAXN];

// Sieve of Eratosthenes over prime_list: after the call a cell holds 1 when
// its index is composite and 0 otherwise (indices 0 and 1 stay 0).
void mk_prime(void) {
    const int limit = static_cast<int>(sizeof(prime_list) / sizeof(prime_list[0]));
    memset(prime_list, 0x00, sizeof(prime_list));
    for (int p = 2; p < limit; ++p) {
        if (prime_list[p]) continue;  // already struck out by a smaller factor
        for (int multiple = p + p; multiple < limit; multiple += p)
            prime_list[multiple] = 1;
    }
}

int main(void) {
    int a, b, c, d;
    mk_prime();
    while (scanf("%d%d%d", &a, &b, &c) && a + b + c) {
        int d = 2;
        for (; (b + 1) * 4 > d; d <<= 1);
        for (int i = 0; i < 4; ++i) {
            for (int j = 0; j < d; ++j) {
                card[i][j] = Complex<long double>(j <= b ? prime_list[j] : 0, 0);
            }
        }
        for (int i = 0; i < c; ++i) {
            int v; char c;
            scanf("%d%c", &v, &c);
            if (c == 'S') card[0][v] = 0;
            if (c == 'H') card[1][v] = 0;
            if (c == 'C') card[2][v] = 0;
            if (c == 'D') card[3][v] = 0;
        }
        for (int i = 0; i < 4; ++i) FFT::run(card[i], d, false);
        for (int j = 0; j < d; ++j) card[0][j] = card[0][j] * card[1][j] * card[2][j] * card[3][j];
        FFT::run(card[0], d, true);
        for (int i = a; i <= b; ++i)
            printf("%lld\n", (long long)(card[0][i].r + 0.5L));
        puts("");
    }
    return 0;
}

FFT 模板初代,這個版本的精度比較差

題目驗證:ZOJ1637

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
// 4041158, AC, 100ms, 22832k
// #pragma GCC target ("avx")
// 要確認 Judge 有支援這一個指令集,才可以開
#pragma GCC optimize ("Os")
#pragma GCC optimize ("O3")
#include <iostream>
#include <algorithm>
#include <vector>
#include <complex>
#include <cmath>
#include <cstring>
//#include <cstdint>
// Drop any M_PIl the headers may have provided so the definition below is
// the one in effect.
#ifdef M_PIl
#undef M_PIl
#endif
// Pi as a literal constant; acos(-1.) would work as well.
// NOTE(review): the literal carries no suffix, so it has type double.
#define M_PIl 3.1415926535897932384626433832795028841971693993751058209749445923078164062862089986280348253421170679
// Keeps the variable declared in `for (int i; ...)` from leaking into the
// enclosing scope (the loop ends up nested inside an else branch).
#define for if (0); else for

// Make sure a type named uint32_t exists.
// NOTE(review): #ifndef only sees macros, so a uint32_t typedef from a header
// is not detected here; `unsigned long int` may also be wider than 32 bits.
#ifndef uint32_t
#define uint32_t unsigned long int
#endif

namespace FFT  // WARNING: do not pull this namespace into the enclosing scope
{
    using namespace std; // be careful

    /**
     * Returns the index of the lowest set bit of PowerOfTwo, i.e. log2 of a
     * power of two. Returns 0 for non-positive input instead of looping
     * forever (there is no set bit to find in 0).
     */
    int NumberOfBitsNeeded(int PowerOfTwo) {
        if (PowerOfTwo <= 0) {
            return 0;
        }
        int bits = 0;
        while (!(PowerOfTwo & (1 << bits))) {
            ++bits;
        }
        return bits;
    }

    /**
     * Reverses the low NumBits bits of a (MSB <-> LSB) with the usual
     * swap-halves ladder on a 32-bit word. NumBits <= 0 yields 0, which
     * avoids shifting by the full word width.
     */
    inline uint32_t FastReverseBits(uint32_t a, int NumBits) {
        if (NumBits <= 0) {
            return 0;
        }
        a = ( ( a & 0x55555555U ) << 1 ) | ( ( a & 0xAAAAAAAAU ) >> 1 ) ;
        a = ( ( a & 0x33333333U ) << 2 ) | ( ( a & 0xCCCCCCCCU ) >> 2 ) ;
        a = ( ( a & 0x0F0F0F0FU ) << 4 ) | ( ( a & 0xF0F0F0F0U ) >> 4 ) ;
        a = ( ( a & 0x00FF00FFU ) << 8 ) | ( ( a & 0xFF00FF00U ) >> 8 ) ;
        a = ( ( a & 0x0000FFFFU ) << 16 ) | ( ( a & 0xFFFF0000U ) >> 16 ) ;
        return a >> (32 - NumBits);
    }

    // Requests GCC's fast-math code generation for FFT(); comment this line
    // out if the judge's compiler rejects the attribute.
    void FFT(bool, const vector<complex<double> >&, vector<complex<double> >&) __attribute__((optimize("fast-math")));

    /**
     * Radix-2 FFT from In into Out.
     *
     * @param InverseTransform  false: forward transform; true: inverse
     *                          transform including the 1/N scaling
     * @param In                input samples; the size must be a power of two
     * @param Out               receives the transform; grown to In.size() if
     *                          it is shorter. Must not be the same object as In.
     *
     * An empty In is a no-op. Twiddle factors are evaluated directly for each
     * stage rather than through a running recurrence, so rounding error does
     * not build up across a block.
     */
    void FFT(bool InverseTransform, const vector<complex<double> >& In, vector<complex<double> >& Out) {
        const int NumSamples = static_cast<int>(In.size());
        if (Out.size() < In.size()) {
            Out.resize(In.size());
        }
        if (NumSamples == 0) {
            return;
        }

        // Simultaneous data copy and bit-reversal ordering into the output.
        const int NumBits = NumberOfBitsNeeded(NumSamples);
        for (int i = 0; i < NumSamples; ++i) {
            Out[FastReverseBits((uint32_t)i, NumBits)] = In[i];
        }

        // Butterfly passes. The n-th twiddle of a stage is exp(i * n * delta)
        // with delta = +/- 2*pi / BlockSize (negative for the inverse).
        const double TwoPi = 2.0 * std::acos(-1.0);
        const double angle_numerator = InverseTransform ? -TwoPi : TwoPi;
        vector<complex<double> > twiddle(NumSamples / 2);
        for (int BlockEnd = 1, BlockSize = 2; BlockSize <= NumSamples; BlockEnd = BlockSize, BlockSize <<= 1) {
            const double delta_angle = angle_numerator / BlockSize;
            for (int n = 0; n < BlockEnd; ++n) {
                twiddle[n] = std::polar(1.0, delta_angle * n);
            }
            // The bound keeps every access inside Out even if the size is not
            // a power of two (the numeric result is meaningless in that case).
            for (int i = 0; i + BlockSize <= NumSamples; i += BlockSize) {
                for (int n = 0; n < BlockEnd; ++n) {
                    const int lo = i + n;
                    const int hi = lo + BlockEnd;
                    const complex<double> a = twiddle[n] * Out[hi];
                    Out[hi] = Out[lo] - a;
                    Out[lo] += a;
                }
            }
        }

        // Normalize if inverse transform.
        if (InverseTransform) {
            const double scale = static_cast<double>(NumSamples);
            for (int i = 0; i < NumSamples; ++i) {
                Out[i] /= scale;
            }
        }
    }

    /**
     * Shared tail of convolution() and mul(): transforms u and v, multiplies
     * the spectra element-wise, transforms back and stores the real parts in
     * ret (resized to u.size()). u and v must have the same size.
     */
    inline void MultiplyInFrequencyDomain(const vector<complex<double> > &u,
                                          const vector<complex<double> > &v,
                                          vector<double> &ret) {
        const size_t n = u.size();
        vector<complex<double> > fu(n), fv(n);
        FFT(false, u, fu);
        FFT(false, v, fv);
        for (size_t i = 0; i < n; ++i) {
            fu[i] *= fv[i];
        }
        FFT(true, fu, fv);  // fv is reused as the output buffer
        ret.resize(n, 0);
        for (size_t i = 0; i < n; ++i) {
            ret[i] = fv[i].real();
        }
    }

    // As above: comment out the next two lines if the compiler does not
    // support the fast-math attribute.
    template<class T>
        void convolution(const vector<T> &a, const vector<T> &b, vector<double> &ret) __attribute__((optimize("fast-math")));
    // Cyclic ("repeat") convolution of two equally long vectors a and b whose
    // length is a power of two. Following the ZOJ statement:
    /**
     * Given two sequences {a1, a2, a3.. an} and {b1, b2, b3... bn},
     * their repeat convolution means:
     * r1 = a1*b1 + a2*b2 + a3*b3 + ... + an*bn
     * r2 = a1*bn + a2*b1 + a3*b2 + ... + an*bn-1
     * r3 = a1*bn-1 + a2*bn + a3*b1 + ... + an*bn-2
     * ...
     * rn = a1*b2 + a2*b3 + a3*b4 + ... + an-1*bn + an*b1
     * Notice n >= 2 and n must be power of 2.
     */
    // In 0-based terms ret[k] = sum_i a[i] * b[(i - k) mod n]. ret is resized
    // to a.size(); b must hold at least a.size() elements. Empty input yields
    // an empty ret.
    template<class T> // T: element type of the input vectors a and b
        void convolution(const vector<T> &a, const vector<T> &b, vector<double> &ret) {
            const int n = static_cast<int>(a.size());
            if (n == 0) {
                ret.clear();
                return;
            }
            vector<complex<double> > sa(n), sb(n);
            for (int i = 0; i < n; ++i) {
                sa[i] = complex<double>(a[i], 0);
            }
            // Index-reversing b turns the cyclic product into the correlation above.
            sb[0] = complex<double>(b[0], 0);
            for (int i = 1; i < n; ++i) {
                sb[i] = complex<double>(b[n - i], 0);
            }
            MultiplyInFrequencyDomain(sa, sb, ret);
        }

    // As above: comment out the next two lines if the compiler does not
    // support the fast-math attribute.
    template<class T>
        void mul(const vector<T> &a, const vector<T> &b, vector<double> &ret) __attribute__((optimize("fast-math")));
    // Cyclic product of a and b: A*B -> a=FFT(A); b=FFT(B) -> c = a dot b ->
    // ANS = IFFT(c). Zero-pad both inputs to a common power-of-two length that
    // can hold the whole product to obtain an ordinary polynomial product.
    // ret is resized to a.size(); b must hold at least a.size() elements.
    template<class T>
        void mul(const vector<T> &a, const vector<T> &b, vector<double> &ret) {
            const int n = static_cast<int>(a.size());
            vector<complex<double> > sa(n), sb(n);
            for (int i = 0; i < n; ++i) {
                sa[i] = complex<double>(a[i], 0);
                sb[i] = complex<double>(b[i], 0);
            }
            MultiplyInFrequencyDomain(sa, sb, ret);
        }
}; // namespace FFT

// Driver for ZOJ 1637. For every test case reads an Ah x Aw image A and a
// Bh x Bw image B and prints the 1-based (row, column) of the top-left corner
// at which B, laid over A, minimises sum(A^2) - 2*sum(A*B) over the covered
// window (sum(B^2) is the same for every position and is left out). The first
// position in row-major order wins ties.
//
// The prefix-sum table is a local buffer sized from the input instead of a
// fixed 512 x 512 global, so indexing up to [Ah][Aw] is always in range.
int main(void) {
    using namespace std;
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);

    vector<int> A, B;        // both images, stored with a common row stride W
    vector<double> conv;     // cyclic correlation of A and B
    vector<double> area;     // (Ah+1) x (Aw+1) prefix sums of A squared

    int T;
    if (!(cin >> T)) {
        return 0;
    }
    while (T-- > 0) {
        int Ah, Aw, Bh, Bw;
        if (!(cin >> Ah >> Aw >> Bh >> Bw)) {
            break;
        }
        if (Ah < 0 || Aw < 0 || Bh < 0 || Bw < 0) {
            break;
        }
        const int W = max(Aw,Bw);
        const int H = max(Ah,Bh);
        const int n = H*W;
        int n2base = 1;
        while (n > n2base) {
            n2base <<= 1; // n2base==2^k , n2base>=n
        }
        A.assign(n2base, 0);
        B.assign(n2base, 0);
        for (int i=0; i<Ah; ++i) {
            for (int j=0; j<Aw; ++j) {
                cin >> A.at(i*W+j);
            }
        }
        for (int i=0; i<Bh; ++i) {
            for (int j=0; j<Bw; ++j) {
                cin >> B.at(i*W+j);
            }
        }
        // conv[i*W+j] = sum over B's cells of A[i+p][j+q] * B[p][q].
        FFT::convolution<int>(A, B, conv);
#ifdef DEBUG
        for (int i=0; i<conv.size(); ++i) {
            int _i=i/W;
            int _j=i%W;
            if (_i+Bh>Ah||_j+Bw>Aw) continue;
            cout << "conv[" << (i/W) << "][" << (i%W) << "] = " << conv.at(i) << endl;
        }
#endif
        // 2D prefix sum of A squared; row 0 and column 0 stay zero.
        const int stride = Aw + 1;
        area.assign(static_cast<size_t>(Ah + 1) * stride, 0.0);
        for (int i=1; i<=Ah; ++i) {
            double sum=0.0;
            for (int j=1; j<=Aw; ++j) {
                const double v = A[(i-1)*W+(j-1)];
                sum += v*v;
                area[i*stride+j] = sum + area[(i-1)*stride+j];
            }
        }
        double minval=1e300;
        int x=0, y=0;
        for (int i=0; i+Bh<=Ah; ++i) {
            for (int j=0; j+Bw<=Aw; ++j) {
                double diff=area[(i+Bh)*stride+(j+Bw)] - area[(i+Bh)*stride+j]
                          - area[i*stride+(j+Bw)] + area[i*stride+j];
#ifdef DEBUG
                cout << "A2[" << i << "][" << j << "] = " << diff << endl;
#endif
                // The true correlation of integer images is an integer; round
                // away FFT noise so equal scores compare equal and the first
                // position keeps the tie.
                diff -= 2.0*floor(conv[i*W+j] + 0.5);
                if (diff<minval) {
                    minval = diff;
                    x=j, y=i;
                }
            }
        }
        cout << y+1 << ' ' << x+1 << '\n';
    }
    return 0;
}