快速数论变换(Number Theoretic Transform)

By Long Luo

之前的文章 快速傅里叶变换(FFT)算法快速傅里叶变换(FFT)算法的实现及优化 详细介绍了 FFT 的具体实现及其实现。

FFT 优点很多,但缺点也很明显。例如单位复根的实部和虚部分别是一个正弦及余弦函数,有大量浮点数计算,计算量很大,而且浮点数运算产生的误差会比较大。

如果我们操作的对象都是整数的话,其实数学家已经发现了一个更好的方法:快速数论变换 (Number Theoretic Transform) 1

快速数论变换(NTT)

FFT 的本质是什么?

  • 将卷积操作变成了乘法操作。

是什么让 FFT 做到了 O(nlogn) 的复杂度?

  • 单位复根

那有没有什么其他的东西也拥有单位根的这些性质呢?

答案当然是有的,原根2 就具有和单位根一样的性质。

所以快速数论变换 NTT 就是以数论为基础的具有循环卷积性质的,用有限域上的单位根来取代复平面上的单位根的 FFT

原根

仿照单位复数根的形式,也将原根的取值看成一个圆,不过这个圆上只有有限个点,每个点表达的是模数的剩余系中的值。

FFT 中,我们总共用到了单位复根的这些性质:

  1. n 个单位复根互不相同;
  2. ωnk=ω2n2k
  3. ωnk=ωnk+n/2
  4. ωna×ωnb=ωna+b

我们发现原根具有和单位复根一样的性质,简单证明3

n 为大于 12 的幂,p 为素数且 n(p1)gp 的一个原根。

我们设 gn=gp1n

  1. gnn=gnp1n=gp1

  2. gnn2=gp12

  3. ganak=gak(p1)an=gk(p1)n=gnk

显然

  1. gnn1(modp)

  2. gnn21(modp)

  3. (gnk+n2)2=gn2k+ngn2k(modp)

证毕。

所以将 gnkgnk+n2 带入本质上和将 ωnkωnk+n2 带入的操作无异。

利用 Vandermonde 矩阵性质,类似 NTT 那样,我们可以从 NTT 变换得到逆变换 INTT 变换,设 x(n) 为整数序列,则有:

NTT : X(m)=i=0Nx(n)amn(modM)

INTT : X(m)=N1i=0Nx(n)amn(modM)

这里 N1amn(modM) 为模意义下的乘法逆元。

当然, NTT 也是有自己的缺点的:比如不能够处理小数的情况,以及不能够处理没有模数的情况。对于模数的选取也有一定的要求,首先是必须要有原根,其次是必须要是 2 的较高幂次的倍数。

NTT 实现

通过上面的分析,开始写代码吧:-)

NTT 也有递归版(Recursion)和迭代版(Iteration) 2 种实现:

递归版(Recursion)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
const long long G = 3;
const long long G_INV = 332748118;
const long long MOD = 998244353;

vector<int> rev;

long long quickPower(long long a, long long b) {
long long res = 1;

while (b > 0) {
if (b & 1) {
res = (res * a) % MOD;
}

a = (a * a) % MOD;
b >>= 1;
}

return res % MOD;
}

void ntt(vector<long long> &a, bool invert) {
int n = a.size();

if (n == 1) {
return;
}

vector<long long> Pe(n / 2), Po(n / 2);

for (int i = 0; 2 * i < n; i++) {
Pe[i] = a[2 * i];
Po[i] = a[2 * i + 1];
}

ntt(Pe, invert);
ntt(Po, invert);

long long wn = quickPower(invert ? G_INV : G, (MOD - 1) / n);
long long w = 1;

for (int i = 0; i < n / 2; i++) {
a[i] = Pe[i] + w * Po[i] % MOD;
a[i] = (a[i] % MOD + MOD) % MOD;
a[i + n / 2] = Pe[i] - w * Po[i] % MOD;
a[i + n / 2] = (a[i + n / 2] % MOD + MOD) % MOD;
w = w * wn % MOD;
}
}

迭代版(Iteration)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
public:
static const long long MOD = 998244353;
static const long long G = 3;
static const int G_INV = 332748118;
vector<int> rev;

long long quickPower(long long a, long long b) {
long long res = 1;

while (b > 0) {
if (b & 1) {
res = (res * a) % MOD;
}

a = (a * a) % MOD;
b >>= 1;
}

return res % MOD;
}

void ntt(vector<long long> &a, bool invert = false) {
int n = a.size();

for (int i = 0; i < n; i++) {
if (i < rev[i]) {
swap(a[i], a[rev[i]]);
}
}

for (int len = 2; len <= n; len <<= 1) {
long long wlen = quickPower(invert ? G_INV : G, (MOD - 1) / len);

for (int i = 0; i < n; i += len) {
long long w = 1;
for (int j = 0; j < len / 2; j++) {
long long u = a[i + j];
long long v = (w * a[i + j + len / 2]) % MOD;
a[i + j] = (u + v) % MOD;
a[i + j + len / 2] = (MOD + u - v) % MOD;
w = (w * wlen) % MOD;
}
}
}

if (invert) {
long long inver = quickPower(n, MOD - 2);
for (int i = 0; i < n; i++) {
a[i] = (long long) a[i] * inver % MOD;
}
}
}

复杂度分析

  • 时间复杂度O((m+n)log(m+n))
  • 空间复杂度O(m+n)

参考资料


  1. Wiki: Number Theoretic Transform↩︎

  2. 快速数论变换(NTT)及蝴蝶操作构造详解↩︎

  3. 学习笔记 - NTT(快速数论变换)↩︎