Fast Fourier Transform (FFT)

2021-10-04   


Problem Statement

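Given the coefficient representations of two polynomials, ordered from lowest to highest degree, $A=[a_0,a_1,\dots,a_{n-1}]$ and $B=[b_0,b_1,\dots,b_{m-1}]$, output the coefficient representation of their product.
The product of the two polynomials is: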

$$
\begin{aligned}
A(x)&=a_0+a_1x+\dots+a_{n-1}x^{n-1}\\
B(x)&=b_0+b_1x+\dots+b_{m-1}x^{m-1}\\
C(x)&=A(x)B(x)=c_0+c_1x+c_2x^2+\dots+c_{m+n-2}x^{m+n-2}
\end{aligned}
$$

Each $c_k$ is given by the convolution

$$
c_k=\sum_{i=0}^{k} a_i b_{k-i},\quad k=0,1,\dots,m+n-2
$$

with the convention that $a_i=0$ for $i\ge n$ and $b_j=0$ for $j\ge m$.

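Multiplying $A(x)$ by $B(x)$ term by term takes $O(mn)$ time. This is very expensive when $m$ and $n$ are both large and of similar size. The rest of this post covers a much more efficient method: the Fast Fourier Transform (FFT).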
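As a baseline, the $O(mn)$ method can be written straight from the convolution formula. The sketch below does this. `naive_polymul` is a hypothetical helper name and is not part of the original post.

```python
def naive_polymul(a, b):
    """Multiply two polynomials given as coefficient lists (lowest degree first)
    by accumulating every product a_i * b_j: O(len(a) * len(b)) multiplications."""
    if not a or not b:
        return []
    c = [0] * (len(a) + len(b) - 1)  # the product has n + m - 1 coefficients
    for i, ai in enumerate(a):
        for j, bj in enumerate(b):
            c[i + j] += ai * bj  # a_i x^i * b_j x^j contributes to the x^(i+j) term
    return c

print(naive_polymul([1, 2], [3, 4]))  # [3, 10, 8], i.e. (1 + 2x)(3 + 4x) = 3 + 10x + 8x^2
```

The FFT method below should produce exactly the same coefficients, up to floating-point rounding.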

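The **Discrete Fourier Transform (DFT)** evaluates a polynomial at $n$ special points, the $n$-th roots of unity. The **Fast Fourier Transform (FFT)** is an efficient algorithm for computing the DFT. FFT speeds up polynomial multiplication in three steps:

1. Use FFT to convert the coefficient representations of both polynomials into point-value representations at these special points.
2. Multiply the two point-value representations pointwise. This gives the point-value representation of the product (the convolution).
3. Apply the **Inverse Discrete Fourier Transform (IDFT)** to that result to recover the coefficient representation of the product.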

Polynomial Representations

A polynomial is usually written in one of two ways: the coefficient representation or the point-value representation. The coefficient representation is:

$$
f(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1},\quad a_{n-1}\neq 0
$$

The point-value representation is:

$$
f(x)=\{(x_0,y_0),(x_1,y_1),\dots,(x_{n-1},y_{n-1})\},\quad y_i=f(x_i)
$$

Here $x_i\neq x_j$ whenever $i\neq j$. With this condition, a polynomial of degree at most $n-1$ is uniquely determined by its values at $n$ distinct points. The proof goes as follows. At a single point $x_i$, with row vector $(x_i^0,x_i^1,\dots,x_i^{n-1})$, we have:

$$
\begin{pmatrix} x_i^0 & x_i^1 & \cdots & x_i^{n-1} \end{pmatrix}
\begin{pmatrix} a_0\\ a_1\\ \vdots\\ a_{n-1} \end{pmatrix}=y_i
$$

At all $n$ points together:

$$
\begin{pmatrix}
x_0^0 & x_0^1 & x_0^2 & \cdots & x_0^{n-1}\\
x_1^0 & x_1^1 & x_1^2 & \cdots & x_1^{n-1}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
x_{n-1}^0 & x_{n-1}^1 & x_{n-1}^2 & \cdots & x_{n-1}^{n-1}
\end{pmatrix}
\begin{pmatrix} a_0\\ a_1\\ \vdots\\ a_{n-1} \end{pmatrix}
=\begin{pmatrix} y_0\\ y_1\\ \vdots\\ y_{n-1} \end{pmatrix}
$$

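In the unknowns $a_i$, the coefficient matrix of this linear system is a Vandermonde matrix. Its determinant is $\prod_{0\le j<i\le n-1}(x_i-x_j)$, which is nonzero when the points are pairwise distinct. The matrix is therefore invertible and the system has a unique solution. That solution uniquely determines the coefficients, and hence the polynomial.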
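The round trip can be checked numerically. The sketch below evaluates a polynomial at distinct points, then solves the Vandermonde system to recover the coefficients. The polynomial and the points are arbitrary choices made for illustration, and the generic solver is `numpy.linalg.solve`, which is not part of the original post.

```python
import numpy as np

# Arbitrary example polynomial f(x) = 1 + 2x + 3x^2 (n = 3 coefficients)
coeffs = np.array([1.0, 2.0, 3.0])
xs = np.array([0.0, 1.0, 2.0])  # any n pairwise distinct points

# Vandermonde matrix: row i is (x_i^0, x_i^1, ..., x_i^{n-1})
V = np.vander(xs, N=len(coeffs), increasing=True)
ys = V @ coeffs  # point-value representation: y_i = f(x_i), here 1, 6, 17

# Distinct points make V invertible, so the coefficients come back uniquely
recovered = np.linalg.solve(V, ys)
print(ys)
print(recovered)  # approximately [1, 2, 3]
```

A generic solve like this costs $O(n^3)$. The rest of the post shows how choosing the roots of unity as the points makes both directions much cheaper.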

FFT

Consider the equation $x^n=1$ over the complex numbers. From algebra, it has exactly $n$ distinct roots: $\omega_n^k=e^{2\pi i k/n}$ for $k=0,1,2,\dots,n-1$. By Euler's formula, $\omega_n^k=e^{2\pi i k/n}=\cos\left(\frac{2\pi k}{n}\right)+i\sin\left(\frac{2\pi k}{n}\right)$.
The key properties of $\omega_n^k$ below each follow from Euler's formula with a short calculation ($m$ is any positive integer, and the last two lines assume $n$ is even):

$$
\begin{aligned}
\omega_n^k&=\omega_n^{k-1}\cdot\omega_n^1\\
\omega_n^0&=\omega_n^n=1\\
\omega_{2n}^{2k}&=\omega_n^k=\omega_{mn}^{mk}\\
\omega_n^{k+\frac{n}{2}}&=-\omega_n^k\\
\omega_n^{2k}&=\omega_{\frac{n}{2}}^k
\end{aligned}
$$
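These identities can be spot-checked numerically for one sample choice of $n$, $k$ and $m$. The sketch below does this. The helper `omega` is a hypothetical name, and a passing check for sample values illustrates the identities but does not prove them.

```python
import numpy as np

def omega(n, k):
    """k-th power of the primitive n-th root of unity: e^{2*pi*i*k/n}."""
    return np.exp(2j * np.pi * k / n)

n, k, m = 8, 3, 5  # arbitrary sample values; n is even for the last two identities
assert np.isclose(omega(n, k), omega(n, k - 1) * omega(n, 1))
assert np.isclose(omega(n, 0), 1) and np.isclose(omega(n, n), 1)
assert np.isclose(omega(2 * n, 2 * k), omega(n, k))
assert np.isclose(omega(m * n, m * k), omega(n, k))
assert np.isclose(omega(n, k + n // 2), -omega(n, k))  # half a turn flips the sign
assert np.isclose(omega(n, 2 * k), omega(n // 2, k))   # squaring halves the order
print("all identities hold for n = 8, k = 3, m = 5")
```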

In addition:

$$
\sum_{j=0}^{n-1}\omega_n^{jk}=
\begin{cases}
0, & k \bmod n \neq 0\\
n, & k \bmod n = 0
\end{cases}
$$

This holds because, when $k \bmod n\neq 0$, we have $\omega_n^k\neq 1$, and the geometric series formula gives:

$$
\sum_{j=0}^{n-1}\omega_n^{jk}=\frac{\omega_n^0\left(1-\left(\omega_n^k\right)^n\right)}{1-\omega_n^k}=\frac{1-\left(\omega_n^n\right)^k}{1-\omega_n^k}=\frac{1-1}{1-\omega_n^k}=0
$$

When $k \bmod n=0$, every term $\omega_n^{jk}$ equals $1$, so the sum is $n$.
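The summation identity can also be checked for a range of $k$ values. The sketch below deliberately includes negative $k$, which appears later in the IDFT proof as $k=i-j$. The helper name `root_power_sum` is hypothetical.

```python
import numpy as np

def root_power_sum(n, k):
    """Numerically compute sum_{j=0}^{n-1} omega_n^{j*k} with omega_n = e^{2*pi*i/n}."""
    idx = np.arange(n)
    return np.sum(np.exp(2j * np.pi * idx * k / n))

n = 8
for k in range(-10, 11):
    expected = n if k % n == 0 else 0  # Python's % is non-negative for negative k
    assert abs(root_power_sum(n, k) - expected) < 1e-9
print("identity verified for n = 8, k = -10..10")
```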

The DFT computes the values of $f(x)=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}$ at the $n$ roots of unity above. A naive DFT evaluates each point with Horner's method:

$$
f(x)=a_0+x\Bigl(a_1+x\bigl(a_2+\dots+x(a_{n-2}+x\,a_{n-1})\cdots\bigr)\Bigr)
$$

Horner's method costs $O(n)$ per point, so evaluating all $n$ points still takes $O(n^2)$ time in total. The points here are special, however: they are the $n$ roots of $x^n=1$. Exploiting their structure brings the time down to $O(n\log n)$, and that algorithm is the FFT. To see how, start from the point-value representation at the roots of unity:

$$
\begin{cases}
f(\omega_n^0)=a_0+a_1(\omega_n^0)^1+a_2(\omega_n^0)^2+\dots+a_{n-1}(\omega_n^0)^{n-1}\\
f(\omega_n^1)=a_0+a_1(\omega_n^1)^1+a_2(\omega_n^1)^2+\dots+a_{n-1}(\omega_n^1)^{n-1}\\
\qquad\vdots\\
f(\omega_n^{n-1})=a_0+a_1(\omega_n^{n-1})^1+a_2(\omega_n^{n-1})^2+\dots+a_{n-1}(\omega_n^{n-1})^{n-1}
\end{cases}
$$
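For reference, the naive DFT looks like the sketch below: Horner's rule is applied once per root of unity, for $O(n^2)$ work in total. It follows the text's convention $\omega_n^k=e^{2\pi i k/n}$. The names `horner` and `naive_dft` are hypothetical.

```python
import cmath

def horner(coeffs, x):
    """Evaluate a_0 + a_1 x + ... + a_{n-1} x^{n-1} with n - 1 multiply-adds."""
    result = 0
    for a in reversed(coeffs):  # start from a_{n-1} and work down to a_0
        result = result * x + a
    return result

def naive_dft(coeffs):
    """Evaluate f at every omega_n^k = e^{2*pi*i*k/n}: n points x O(n) each = O(n^2)."""
    n = len(coeffs)
    return [horner(coeffs, cmath.exp(2j * cmath.pi * k / n)) for k in range(n)]

values = naive_dft([1, 2, 3, 4])  # f(x) = 1 + 2x + 3x^2 + 4x^3 at 1, i, -1, -i
print(values)  # approximately [10, -2-2j, -2, -2+2j]
```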

Now take $f(x)$ and split the even-indexed terms from the odd-indexed ones (assume $n$ is even):

$$
\begin{aligned}
f(x)&=a_0+a_1x+a_2x^2+\dots+a_{n-1}x^{n-1}\\
&=(a_0+a_2x^2+\dots+a_{n-2}x^{n-2})+x(a_1+a_3x^2+\dots+a_{n-1}x^{n-2})
\end{aligned}
$$

Let

$$
\begin{aligned}
f_1(x)&=a_0+a_2x+\dots+a_{n-2}x^{\frac{n}{2}-1}\\
f_2(x)&=a_1+a_3x+\dots+a_{n-1}x^{\frac{n}{2}-1}
\end{aligned}
$$

Then

$$
f(x)=f_1(x^2)+x\,f_2(x^2)
$$
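The even/odd split is plain algebra, so it holds at any point, not only at roots of unity. The sketch below checks it with integer arithmetic, which makes equality exact. The coefficients and test points are arbitrary, and `evaluate` is a hypothetical helper.

```python
def evaluate(coeffs, x):
    """Direct evaluation of sum_i coeffs[i] * x**i."""
    return sum(c * x ** i for i, c in enumerate(coeffs))

a = [5, -1, 4, 2, 0, 7, 3, 1]  # arbitrary example with n = 8
f1 = a[0::2]  # even-indexed coefficients a_0, a_2, ..., a_{n-2}
f2 = a[1::2]  # odd-indexed coefficients a_1, a_3, ..., a_{n-1}

for x in [2, -3, 5]:
    # f(x) == f1(x^2) + x * f2(x^2), exactly, since everything is an integer
    assert evaluate(a, x) == evaluate(f1, x * x) + x * evaluate(f2, x * x)
print("f(x) = f1(x^2) + x f2(x^2) holds at x = 2, -3, 5")
```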

Therefore, for $k=0,1,\dots,\frac{n}{2}-1$:

$$
\begin{aligned}
f(\omega_n^k)&=f_1(\omega_n^{2k})+\omega_n^k\cdot f_2(\omega_n^{2k})=f_1(\omega_{\frac{n}{2}}^k)+\omega_n^k\cdot f_2(\omega_{\frac{n}{2}}^k)\\
f(\omega_n^{k+\frac{n}{2}})&=f_1(\omega_n^{2k+n})+\omega_n^{k+\frac{n}{2}}\cdot f_2(\omega_n^{2k+n})\\
&=f_1(\omega_{\frac{n}{2}}^k)-\omega_n^k\cdot f_2(\omega_{\frac{n}{2}}^k)
\end{aligned}
$$

The last step uses $\omega_n^{2k+n}=\omega_n^{2k}=\omega_{\frac{n}{2}}^k$ and $\omega_n^{k+\frac{n}{2}}=-\omega_n^k$.

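Both halves of the output therefore need only the values of $f_1$ and $f_2$, and each of these polynomials is half the size of $f$. This is where divide and conquer pays off. Suppose we already know the values of $f_1$ and $f_2$ at $\omega_{n/2}^0,\omega_{n/2}^1,\dots,\omega_{n/2}^{n/2-1}$. Then each pair $f(\omega_n^k)$, $f(\omega_n^{k+n/2})$ takes constant time, so combining all $n$ outputs costs $O(n)$. The running time satisfies $T(n)=2T(n/2)+O(n)$, which gives $T(n)=O(n\log n)$. The $O(n)$ term is the cost of the combine step at each level of the recursion: computing the powers $\omega_n^k$ and the $n$ output values.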
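The derivation above translates almost line for line into a recursive function. The sketch below is a minimal version that assumes the input length is a power of 2. It returns $f(\omega_n^0),\dots,f(\omega_n^{n-1})$ using the text's convention $\omega_n^k=e^{2\pi i k/n}$. The name `fft_rec` is hypothetical, and its `inv` flag (which flips the sign of the exponent) mirrors the C++ versions in the post.

```python
import cmath

def fft_rec(coeffs, inv=1):
    """Recursive radix-2 FFT based on f(x) = f1(x^2) + x * f2(x^2).

    Returns [f(w^0), ..., f(w^{n-1})] with w = e^{inv * 2*pi*i/n}.
    len(coeffs) must be a power of 2.
    """
    n = len(coeffs)
    if n == 1:
        return [complex(coeffs[0])]  # a constant polynomial has the same value everywhere
    even = fft_rec(coeffs[0::2], inv)  # f1 at the (n/2)-th roots of unity
    odd = fft_rec(coeffs[1::2], inv)   # f2 at the (n/2)-th roots of unity
    out = [0j] * n
    for k in range(n // 2):
        w = cmath.exp(inv * 2j * cmath.pi * k / n)  # omega_n^k
        out[k] = even[k] + w * odd[k]               # f(omega_n^k)
        out[k + n // 2] = even[k] - w * odd[k]      # f(omega_n^{k+n/2}), since omega_n^{k+n/2} = -omega_n^k
    return out

print(fft_rec([1, 2, 3, 4]))  # approximately [10, -2-2j, -2, -2+2j]
```

Each call does $O(n)$ work outside its two half-size recursive calls, which is exactly the recurrence $T(n)=2T(n/2)+O(n)$.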

IDFT

After a polynomial has been converted from coefficient representation to point-value representation, how do we convert it back? Start from:

$$
\begin{pmatrix}
x_0^0 & x_0^1 & x_0^2 & \cdots & x_0^{n-1}\\
x_1^0 & x_1^1 & x_1^2 & \cdots & x_1^{n-1}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
x_{n-1}^0 & x_{n-1}^1 & x_{n-1}^2 & \cdots & x_{n-1}^{n-1}
\end{pmatrix}
\begin{pmatrix} a_0\\ a_1\\ \vdots\\ a_{n-1} \end{pmatrix}
=\begin{pmatrix} y_0\\ y_1\\ \vdots\\ y_{n-1} \end{pmatrix}
$$

We only need to multiply both sides by the inverse of the Vandermonde matrix $X$. Inverting a general matrix this way (for example by Gaussian elimination) takes $O(n^3)$ time. Here, though, the Vandermonde matrix is a special one:

$$
W=\begin{pmatrix}
(\omega_n^0)^0 & (\omega_n^0)^1 & (\omega_n^0)^2 & \cdots & (\omega_n^0)^{n-1}\\
(\omega_n^1)^0 & (\omega_n^1)^1 & (\omega_n^1)^2 & \cdots & (\omega_n^1)^{n-1}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
(\omega_n^{n-1})^0 & (\omega_n^{n-1})^1 & (\omega_n^{n-1})^2 & \cdots & (\omega_n^{n-1})^{n-1}
\end{pmatrix}
=\begin{pmatrix}
1 & 1 & 1 & \cdots & 1\\
1 & \omega_n^{1\times 1} & \omega_n^{1\times 2} & \cdots & \omega_n^{1\times (n-1)}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
1 & \omega_n^{(n-1)\times 1} & \omega_n^{(n-1)\times 2} & \cdots & \omega_n^{(n-1)\times (n-1)}
\end{pmatrix}
$$

Its inverse is:

$$
W^{-1}=\frac{1}{n}\begin{pmatrix}
1 & 1 & 1 & \cdots & 1\\
1 & \omega_n^{-1\times 1} & \omega_n^{-1\times 2} & \cdots & \omega_n^{-1\times (n-1)}\\
\vdots & \vdots & \vdots & \ddots & \vdots\\
1 & \omega_n^{-(n-1)\times 1} & \omega_n^{-(n-1)\times 2} & \cdots & \omega_n^{-(n-1)\times (n-1)}
\end{pmatrix}
$$

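In other words, each entry of $W^{-1}$ is the complex conjugate of the corresponding entry of $W$, divided by $n$. (This works because $|\omega_n|=1$, so $\overline{\omega_n^{k}}=\omega_n^{-k}$.) Here is a direct computational proof. Let $M=WW^{-1}$. Then:

$$
M_{ij}=\sum_{k=0}^{n-1}\omega_n^{i\cdot k}\cdot\frac{\omega_n^{-k\cdot j}}{n}=\frac{1}{n}\sum_{k=0}^{n-1}\omega_n^{k(i-j)}=
\begin{cases}
0, & i\neq j\\
1, & i=j
\end{cases}
$$

The last step applies the summation identity above with $k$ replaced by $i-j$. Since $|i-j|<n$, $(i-j)\bmod n=0$ holds exactly when $i=j$.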
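The claim that $W^{-1}$ is "conjugate, then divide by $n$" can be spot-checked numerically. The sketch below does this for one size. The choice $n=8$ and the sample coefficient vector are arbitrary.

```python
import numpy as np

n = 8
idx = np.arange(n)
# W[i][k] = omega_n^{i*k} with omega_n = e^{2*pi*i/n}
W = np.exp(2j * np.pi * np.outer(idx, idx) / n)
# Claimed inverse: entrywise complex conjugate divided by n, i.e. omega_n^{-i*k} / n
W_inv = np.conj(W) / n

assert np.allclose(W @ W_inv, np.eye(n))

# Round trip: coefficients -> point values (DFT) -> coefficients (IDFT)
a = np.array([1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0])
y = W @ a
assert np.allclose(W_inv @ y, a)
print("W @ W_inv is the identity for n = 8")
```

Because $W^{-1}$ has the same structure as $W$ with $\omega_n$ replaced by $\omega_n^{-1}$, the IDFT can reuse the FFT routine with conjugated roots and a final division by $n$. That is what the `inv = -1` branches in the implementations below do.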

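So the product of the two matrices is indeed the identity, and we can recover the coefficient representation. The whole process is shown in the figure below: coefficients → FFT → pointwise product → IFFT → coefficients.

(Figure: FFT)

Note: in practice, for convenience, both polynomials are zero-padded to the same number of terms. That number is a power of 2 no smaller than the number of coefficients in the product ($n+m-1$), so every level of the recursion splits evenly in half.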
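As a cross-check against a library implementation, the sketch below runs the same pipeline with NumPy's real `numpy.fft.fft` / `numpy.fft.ifft`. Two differences from the text are worth knowing. NumPy's forward transform uses $e^{-2\pi i jk/n}$, the opposite sign of the text's convention, but polynomial multiplication works with either sign as long as the forward and inverse transforms are paired. Also, `ifft` already includes the $1/n$ factor. NumPy accepts any length, so the power-of-2 padding here only mirrors the note above. `polymul_numpy_fft` is a hypothetical name.

```python
import numpy as np

def polymul_numpy_fft(a, b):
    """Multiply integer polynomials (lowest degree first) using NumPy's FFT."""
    size = len(a) + len(b) - 1      # number of coefficients in the product
    N = 1
    while N < size:                 # pad to a power of 2, as in the note
        N *= 2
    A = np.fft.fft(a, N)            # zero-pads a to length N, then transforms
    B = np.fft.fft(b, N)
    c = np.fft.ifft(A * B)[:size]   # pointwise product, inverse transform, drop padding
    return [int(round(v)) for v in c.real]  # round away floating-point error

print(polymul_numpy_fft([1, 2], [3, 4]))  # [3, 10, 8]
```

Rounding to the nearest integer is only valid when the true coefficients are integers and small enough that floating-point error stays below 0.5.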

Implementation

Python Implementation

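```python
'''Polynomial multiplication via FFT'''
import numpy as np
# Direct DFT of each column of xn (O(N^2) per column); used only for the small base case
def singleDFT(xn):
    N = xn.shape[0]
    j = np.arange(N)
    k = j.reshape((N, 1))
    w = np.exp(-2j * np.pi * k * j / N)
    return np.dot(w, xn)
# Direct inverse DFT of each column of xk (includes the 1/N factor)
def singleIDFT(xk):
    N = xk.shape[0]
    j = np.arange(N)
    k = j.reshape((N, 1))
    w = np.exp(2j * np.pi * k * j / N)
    return 1 / N * np.dot(w, xk)
# FFT: small direct DFTs first, then vectorized butterflies merge them level by level
def FFT(xn):
    N = xn.shape[0]  # N must be a power of 2
    N_min = min(N, 32)  # size of the base-case DFTs; reshaping to (N, -1) would make this a plain O(N^2) DFT
    AX = singleDFT(xn=xn.reshape((N_min, -1)))
    while AX.shape[0] < N:
        AX_even = AX[:, :AX.shape[1] // 2]  # transforms of the even-indexed sub-polynomials
        AX_odd = AX[:, AX.shape[1] // 2:]  # transforms of the odd-indexed sub-polynomials
        w = np.exp(-2j * np.pi * np.arange(AX.shape[0]) / AX.shape[0] / 2)  # twiddle factors
        w = w.reshape((w.shape[0], 1))
        w_AX_odd = w * AX_odd
        AX = np.vstack([AX_even + w_AX_odd, AX_even - w_AX_odd])
    return AX.ravel()
# IFFT: same structure with conjugated roots, normalized by 1/N at the end
def IFFT(xk):
    N = xk.shape[0]
    N_min = min(N, 32)
    AX = N_min * singleIDFT(xk=xk.reshape((N_min, -1)))  # undo singleIDFT's 1/N_min factor
    while AX.shape[0] < N:
        AX_even = AX[:, :AX.shape[1] // 2]
        AX_odd = AX[:, AX.shape[1] // 2:]
        w = np.exp(2j * np.pi * np.arange(AX.shape[0]) / AX.shape[0] / 2)
        w = w.reshape((w.shape[0], 1))
        w_AX_odd = w * AX_odd
        AX = np.vstack([AX_even + w_AX_odd, AX_even - w_AX_odd])
    return 1 / N * AX.ravel()
# Polynomial multiplication via FFT
def FFT_polymul(x1, x2):
    '''
    x1, x2: lists of coefficients, lowest degree first (a_0 + a_1 x + ... + a_{n-1} x^{n-1})
    '''
    # Pad both polynomials to N = 2^poly terms, enough to hold the product
    poly = 1
    while 2 ** poly < len(x1) + len(x2):
        poly += 1
    N = 2 ** poly

    A = FFT(xn=np.array(x1 + [0] * (N - len(x1))))
    B = FFT(xn=np.array(x2 + [0] * (N - len(x2))))
    C = A * B  # pointwise product of the point-value representations
    D = IFFT(xk=C)
    # Round away floating-point error and drop the zero padding beyond the product's degree
    res = [int(round(v.real)) for v in D[:len(x1) + len(x2) - 1]]
    return res  # coefficients of the product, lowest degree first

x1 = [0, 1, 2, 3, 4, 6, 9]
x2 = [5, 6, 7, 8]
r = FFT_polymul(x1=x1, x2=x2)
print(r)  # [0, 5, 16, 34, 60, 91, 133, 128, 111, 72]
```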

C++ Recursive Implementation

```cpp
#include <bits/stdc++.h>
using namespace std;

const int N = 300010;
const double PI = acos(-1);

int n, m; // degrees of the two polynomials (n + 1 and m + 1 coefficients)
complex<double> a[N], b[N];
int bit, tot; // bit: number of binary digits, tot: total number of terms (2^bit)

// inv = 1: DFT; inv = -1: IDFT (without the final division by tot)
void fft(int tot, complex<double> a[], int inv){
    if(tot == 1) return; // a single constant term is its own transform
    vector<complex<double>> a1(tot >> 1), a2(tot >> 1); // heap storage: VLAs are non-standard and can overflow the stack
    for(int i = 0; i < tot; i += 2) // split by index parity
        a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1];
    fft(tot >> 1, a1.data(), inv); // recurse on both halves
    fft(tot >> 1, a2.data(), inv);
    complex<double> wk = {cos(2.0 * PI / tot), inv * sin(2.0 * PI / tot)}; // primitive tot-th root of unity
    complex<double> w = {1, 0};
    for(int i = 0; i < (tot >> 1); i ++, w *= wk){
        a[i] = a1[i] + w * a2[i];
        a[i + (tot >> 1)] = a1[i] - w * a2[i];
    }
}

int main(){
    cin >> n >> m;
    double x; // double, because brace-initializing complex<double> from a non-constant int is a narrowing error
    for(int i = 0; i <= n; i ++) {
        cin >> x;
        a[i] = {x, 0};
    }
    for(int i = 0; i <= m; i ++){
        cin >> x;
        b[i] = {x, 0};
    }
    while((1 << bit) < n + m + 1) bit ++;
    tot = 1 << bit;
    fft(tot, a, 1); // 1: DFT
    fft(tot, b, 1);
    for(int i = 0; i < tot; i ++) a[i] = a[i] * b[i];
    fft(tot, a, -1); // -1: IDFT
    for(int i = 0; i < n + m + 1; i ++)
        cout << lround(real(a[i]) / tot) << " "; // divide by tot and round to nearest (also correct for negatives)
}
```

C++ Iterative Implementation

```cpp
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <complex>

using namespace std;

const int N = 300010;
const double PI = acos(-1);

int n, m; // degrees of the two polynomials

complex<double> a[N], b[N];

int rev[N], bit, tot; // rev: bit-reversal permutation, bit: number of binary digits, tot: total number of terms

void fft(complex<double> a[], int inv){ // in-place FFT; inv = 1: DFT, inv = -1: IDFT (without the division by tot)
    for (int i = 0; i < tot; i ++ ) // bit-reversed order puts the recursion's leaves next to each other
        if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int mid = 1; mid < tot; mid <<= 1) { // merge blocks of size mid into blocks of size 2 * mid
        complex<double> w1 = {cos(PI / mid), inv * sin(PI / mid)}; // primitive (2 * mid)-th root of unity
        for (int i = 0; i < tot; i += mid * 2) {
            complex<double> wk = {1, 0};
            for (int j = 0; j < mid; j ++, wk = wk * w1) {
                auto x = a[i + j], y = wk * a[i + j + mid]; // butterfly
                a[i + j] = x + y, a[i + j + mid] = x - y;
            }
        }
    }
}

int main(){
    cin >> n >> m;
    double x;
    for (int i = 0; i <= n; i ++ ) {
        cin >> x;
        a[i] = {x, 0};
    }
    for (int i = 0; i <= m; i ++ ) {
        cin >> x;
        b[i] = {x, 0};
    }
    while ((1 << bit) < n + m + 1) bit ++; // make the number of terms a power of 2
    tot = 1 << bit;
    for (int i = 1; i < tot; i ++ ) // rev[0] = 0; starting at 1 avoids a shift by -1 when bit == 0
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    fft(a, 1), fft(b, 1);
    for (int i = 0; i < tot; i ++ ) a[i] = a[i] * b[i];
    fft(a, -1);
    for (int i = 0; i < n + m + 1; i ++ )
        cout << lround(real(a[i]) / tot) << " "; // divide by tot and round to the nearest integer

    return 0;
}
```
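The iterative version first permutes the input with `rev`, so that each group of elements the recursion would pair up ends up adjacent. It then merges blocks of size `mid` bottom-up. The recurrence for `rev` reverses the lowest `bit` binary digits of each index. The Python sketch below reproduces that recurrence and compares it with a direct string-based bit reversal. Both function names are hypothetical.

```python
def bit_reverse_table(bit):
    """rev[i] = i with its lowest `bit` binary digits reversed, built with the same
    recurrence as the iterative C++ code: rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1))."""
    tot = 1 << bit
    rev = [0] * tot
    for i in range(1, tot):
        # rev[i >> 1] reverses i without its lowest digit; shifting it right makes room,
        # and the lowest digit of i becomes the highest digit of the result
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1))
    return rev

def bit_reverse_slow(i, bit):
    """Reference: reverse the binary string of i, zero-padded to `bit` digits."""
    return int(format(i, f"0{bit}b")[::-1], 2) if bit > 0 else 0

print(bit_reverse_table(3))  # [0, 4, 2, 6, 1, 5, 3, 7]
```

For `bit = 3`, index 1 (`001`) swaps with 4 (`100`) and 3 (`011`) swaps with 6 (`110`). The first butterfly pass (`mid = 1`) then combines pairs that the recursive version would have reached after splitting by parity three times.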

Q.E.D.


I am Star, a silver star with a freshly sharpened blade and a cold gleam; I will never fade.