问题描述
输入两个多项式的系数表示(低幂到高幂) A=[a0,a1,...,an−1],B=[b0,b1,...,bm−1],输出这两个多项式乘积的系数表示。
两个多项式的乘积:
A(x)B(x)C(x)=a0+a1x+...+an−1xn−1=b0+b1x+...+bm−1xm−1=A(x)B(x)=c0+c1x+c2x2+...+cm+n−1xm+n−1
其中 ck 由卷积表达式给出:
ck=i=0∑kaibk−i,k=0,1,...,m+n−1
如果直接遍历 A(x),B(x) 做乘法时间复杂度是 O(mn),这在 m,n 数值比较接近且比较大时非常消耗时间,下面考虑一种新的高效的方法——快速傅里叶变换(FFT)。
离散傅里叶变换(DFT) 是来计算多项式在 n 个特殊点(单位根)的值。而 快速傅里叶变换(FFT) 是一种快速有效率的对 DFT 的实现。FFT 加速多项式乘法,其基本思想是将两个多项式的系数表示通过 FFT 转化为特殊点处的点值表示,然后计算两个多项式点值表示的乘积得到原多项式卷积的点值表示,再将多项式卷积的点值表示进行 逆离散傅里叶变换(IDFT) 就得到了乘积多项式的系数表示。
多项式的表示
一个多项式一般有有两种表示方法,系数表示法和点值表示法。系数表示法:
f(x)=a0+a1x+a2x2+...+an−1xn−1,an−1=0
点值表示法:
f(x)={(x0,y0),(x1,y1),...,(xn−1,yn−1)},xi=(xi0,xi1,...,xin−1)
其中 i=j 时 xi=xj。这样一个 n−1 次多项式就由 n 个不同点处的取值唯一确定。证明如下,在一个点处我们有:
(xi0xi1……xin−1)⎝⎜⎜⎜⎜⎜⎛a0a1……an−1⎠⎟⎟⎟⎟⎟⎞=yi
在 n 个点处:
⎝⎜⎜⎜⎜⎜⎛x00x10……xn−10x01x11……xn−11x02x12……xn−12……………x0n−1x1n−1……xn−1n−1⎠⎟⎟⎟⎟⎟⎞⎝⎜⎜⎜⎛a0a1…an−1⎠⎟⎟⎟⎞=⎝⎜⎜⎜⎜⎜⎛y0y1……yn−1⎠⎟⎟⎟⎟⎟⎞
当点互不相同时,对于未知数 ai 而言,其系数矩阵为范德蒙矩阵,必可逆,故有唯一解,这唯一确定了多项式系数也即唯一确定了多项式。
FFT
在复数域内考虑方程 xn=1,由代数学知识我们知道,它有 n 个根,分别为 wnk=e2π⋅nk,k=0,1,2,...,n−1,其中由 Euler 公式,wnk=e2π⋅nk=cos(2π⋅nk)+i⋅sin(2π⋅nk)。
下面来看一下关于 wnk 的重要性质,这些性质可以利用Euler公式经过简单计算轻易得到:
wnkwn0w2n2kwnk+2nwn2k=wnk−1⋅wn1=wnn=1=wnk=wmnmk=−wnk=w2nk
此外还有:
j=0∑n−1wnjk={0,kmodn=0n,kmodn=0
这是因为:
j=0∑n−1ωnjk=1−ωnkωn0(1−(ωnk)n)=1−ωnk1−(ωnn)k,kmodn=0
DFT 就是要求得 f(x)=a0+a1x+a2x2+...+an−1xn−1 在上述 n 个单位根处的取值。如果是朴素的DFT,使用 Horner 方法:
f(x)=a0+x(a1+x(a2+...+xan−1)
时间复杂度将仍是 O(n2)。但是现在是在一些特殊的点(xn=1 的 n 个单位根),利用特殊性可以将时间复杂度降低到 O(nlogn),这就是 FFT。其思想如下,考虑在单位根处的点值表示:
⎩⎪⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎪⎧f(ωn0)f(ωn1)…f(ωnn−1)=a0+a1(ωn0)1+a2(ωn0)2+…+an−1(ωn0)n−1=a0+a1(ωn1)2+a2(ωn1)2+…+an−1(ωn1)n−1=a0+a1(ωnn−1)1+a2(ωnn−1)2+…+an−1(ωnn−1)n−1
现在考察 f(x) ,将奇数项与偶数项分开:
f(x)=a0+a1x+a2x2+...+an−1xn−1=(a0+a2x2+...+an−2xn−2)+x(a1+a3x2+...+an−1xn−2)
令
f1(x)f2(x)=(a0+a2x+...+an−2xn−2/2)=(a1+a3x+...+an−1xn−2/2)
则可得
f(x)=f1(x2)+xf2(x2)
于是
f(ωnk)f(ωnk+2n)=f1(ωn2k)+ωnk⋅f2(ωn2k)=f1(ω2nk)+ωnk⋅f2(ω2nk)=f1(ωn2k+n)+ωnk+2n⋅f2(ωn2k+n)=f(ωnk+2n)=f1(ω2nk)−ωnk⋅f2(ω2nk)
可以看到计算 f1 和 f2 各自只需要 f 规模的一半计算量即可得到。这就厉害了,利用 Divide and Conquer 思想,如果我们已经知道 f1 和 f2 分别在 wn/20,wn/21,wn/22,...,wn/2n/2−1 处的值,可以在常数时间求得 f(x) 的值。所以总的时间复杂度为 T(n)=2T(n/2)+O(n)=O(nlogn)。其中 O(n) 是每次分治求对应单位根花费的时间。
IDFT
将多项式从系数表示转化为点值表示后,如何将其变回来?由
⎝⎜⎜⎜⎜⎜⎛x00x10……xn−10x01x11……xn−11x02x12……xn−12……………x0n−1x1n−1……xn−1n−1⎠⎟⎟⎟⎟⎟⎞⎝⎜⎜⎜⎛a0a1…an−1⎠⎟⎟⎟⎞=⎝⎜⎜⎜⎜⎜⎛y0y1……yn−1⎠⎟⎟⎟⎟⎟⎞
只需要两边同称范德蒙矩阵 X 的逆。对一个普通的范德蒙矩阵求逆,时间复杂度是 O(n3) ,但是现在我们有一个特殊的范德蒙矩阵:
W=⎝⎜⎜⎜⎜⎜⎛(wn0)0(wn1)0……(wnn−1)0(wn0)1(wn1)1……(wnn−1)1(wn0)2(wn1)2……(wnn−1)2……………(wn0)n−1(wn1)n−1……(wnn−1)n−1⎠⎟⎟⎟⎟⎟⎞=⎝⎜⎜⎜⎜⎜⎛11……11wn1×1……wn(n−1)×11wn1×2……wn(n−1)×2……………1wn1×(n−1)……wn(n−1)×(n−1)⎠⎟⎟⎟⎟⎟⎞
则
W−1=n1⎝⎜⎜⎜⎜⎜⎛11……11wn−1×1……wn−(n−1)×11wn−1×2……wn−(n−1)×2……………1wn−1×(n−1)……wn−(n−1)×(n−1)⎠⎟⎟⎟⎟⎟⎞
即原矩阵每个元素的共轭再除以 n,下面给一个计算性证明,设上述两个矩阵的乘积矩阵为 M,则:
Mij=k=0∑n−1ωni⋅k⋅nωn−k⋅j=n1k=0∑n−1ωnk(i−j)={0,i=j1,i=j
可见两矩阵的乘积确为单位阵。这样我们就可以反求出系数表示。整个过程如下图所示:

注意:实际为了处理方便会将多项式项数统一为 2 的幂次个。
代码实现
Python 实现
'''多项式乘法的 FFT 实现'''
import numpy as np
from decimal import Decimal
# 单点处离散傅里叶变换
def singleDFT(xn):
N = xn.shape[0]
j = np.arange(N)
k = j.reshape((N, 1))
w = np.exp(-2j * np.pi * k * j / N)
return np.dot(w, xn)
# 单点处逆离散傅里叶变换
def singleIDFT(xk):
N = xk.shape[0]
j = np.arange(N)
k = j.reshape((N, 1))
w = np.exp(2j * np.pi * k * j / N)
return 1 / N * np.dot(w, xk)
# FFT 过程
def FFT(xn):
N = xn.shape[0]
AX = singleDFT(xn=xn.reshape((N, -1)))
while AX.shape[0] < N:
AX_even = AX[:, :int(AX.shape[1] / 2)] # 偶数项多项式
AX_odd = AX[:, int(AX.shape[1] / 2):] # 奇数项多项式
w = np.exp(-2j * np.pi * np.arange(AX.shape[0]) / AX.shape[0] / 2)
w = w.reshape((w.shape[0], 1))
w_AX_odd = w * AX_odd
AX = np.vstack([AX_even + w_AX_odd, AX_even - w_AX_odd])
return AX.ravel()
# IFFT 过程
def IFFT(xk):
N = xk.shape[0]
AX = N * singleIDFT(xk=xk.reshape((N, -1)))
while AX.shape[0] < N:
AX_even = AX[:, :int(AX.shape[1] / 2)]
AX_odd = AX[:, int(AX.shape[1] / 2):]
w = np.exp(2j * np.pi * np.arange(AX.shape[0]) / AX.shape[0] / 2)
w = w.reshape((w.shape[0], 1))
w_AX_odd = w * AX_odd
AX = np.vstack([AX_even + w_AX_odd, AX_even - w_AX_odd])
return 1 / N * AX.ravel()
# 快速傅里叶变换实现多项式乘法
def FFT_polymul(x1, x2):
'''
传入列表x1, x2: 多项式的系数表示,低幂到高幂(a_0+a_1x+a_2x^2+a_n-1x^n-1)
'''
# 将两个多项式的项数补全为2的幂次
poly = 1
while 2 ** poly < len(x1) + len(x2):
poly += 1
N = 2 ** poly
A = FFT(xn=np.array(x1 + [0] * (N - len(x1))))
B = FFT(xn=np.array(x2 + [0] * (N - len(x2))))
C = A * B
D = IFFT(xk=C)
res = []
for i in range(D.shape[0]):
res.append(round(Decimal(D[i].real)))
return res # 返回系数表示(低幂到高幂)
x1 = [0, 1, 2, 3, 4, 6, 9]
x2 = [5, 6, 7, 8]
r = FFT_polymul(x1=x1, x2=x2)
print(r)
C++ 递归版本实现
#include <bits/stdc++.h>
#include <ccomplex>
using namespace std;
const int N = 300010;
const double PI = acos(-1);
int n, m;
complex<double> a[N], b[N];
int bit, tot; // bit: 总二进制位数, tot: 总项数
void fft(int tot, complex<double>a[], int inv){
if(tot == 1) return; // 多项式只有一个常数项直接返回
complex<double>a1[tot >> 1], a2[tot >> 1];
for(int i = 0; i < tot; i += 2) // 根据下标奇偶性划分
a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1];
fft(tot >> 1, a1, inv); // 递归处理
fft(tot >> 1, a2, inv);
complex<double>wk = {cos(2.0 * PI / tot), inv * sin(2.0 * PI / tot)}; // 单位根
complex<double>w = {1, 0};
for(int i = 0; i < (tot >> 1); i ++, w *= wk){
a[i] = a1[i] + w * a2[i];
a[i + (tot >> 1)] = a1[i] - w * a2[i];
}
}
int main(){
cin >> n >> m;
int x;
for(int i = 0; i <= n; i ++) {
cin >> x;
a[i] = {x, 0};
}
for(int i = 0; i <= m; i ++){
cin >> x;
b[i] = {x, 0};
}
while((1 << bit) < n + m + 1) bit ++;
tot = 1 << bit;
fft(tot, a, 1); // -1 表示 IDFT
fft(tot, b, 1);
for(int i = 0; i < tot; i ++) a[i] = a[i] * b[i];
fft(tot, a, -1);
for(int i = 0; i < n + m + 1; i ++)
cout << (int)(real(a[i]) / tot + 0.5) << " ";
}
C++ 迭代版本实现
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <ccomplex>
using namespace std;
const int N = 300010;
const double PI = acos(-1);
int n, m;
complex<double> a[N], b[N];
int rev[N], bit, tot; // bit: 总二进制位数, tot: 总项数
void fft(complex<double> a[], int inv){ // FFT 过程
for (int i = 0; i < tot; i ++ )
if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int mid = 1; mid < tot; mid <<= 1) {
complex<double> w1 = {cos(PI / mid), inv * sin(PI / mid)};
for (int i = 0; i < tot; i += mid * 2) {
complex<double> wk = {1, 0};
for (int j = 0; j < mid; j ++, wk = wk * w1) {
auto x = a[i + j], y = wk * a[i + j + mid];
a[i + j] = x + y, a[i + j + mid] = x - y;
}
}
}
}
int main(){
cin >> n >> m;
double x;
for (int i = 0; i <= n; i ++ ) {
cin >> x;
a[i] = {x, 0};
}
for (int i = 0; i <= m; i ++ ) {
cin >> x;
b[i] = {x, 0};
}
while ((1 << bit) < n + m + 1) bit ++; // 为了项数是 2 的幂次个
tot = 1 << bit; // 为了项数是 2 的幂次个
for (int i = 0; i < tot; i ++ )
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
fft(a, 1), fft(b, 1);
for (int i = 0; i < tot; i ++ ) a[i] = a[i] * b[i];
fft(a, -1);
for (int i = 0; i < n + m + 1; i ++ )
cout << (int)(real(a[i]) / tot + 0.5) << " "; // 加 0.5 为了整数
return 0;
}
Q.E.D.