By Long Luo
之前的文章 快速傅里叶变换(FFT)算法 和 快速傅里叶变换(FFT)算法的实现及优化 详细介绍了 的具体实现及其实现。
优点很多,但缺点也很明显。例如单位复根的实部和虚部分别是一个正弦及余弦函数,有大量浮点数计算,计算量很大,而且浮点数运算产生的误差会比较大。
如果我们操作的对象都是整数的话,其实数学家已经发现了一个更好的方法:快速数论变换 。
快速数论变换(NTT)
的本质是什么?
是什么让 做到了 的复杂度?
那有没有什么其他的东西也拥有单位根的这些性质呢?
答案当然是有的,原根 就具有和单位根一样的性质。
所以快速数论变换 就是以数论为基础的具有循环卷积性质的,用有限域上的单位根来取代复平面上的单位根的 。
原根
仿照单位复数根的形式,也将原根的取值看成一个圆,不过这个圆上只有有限个点,每个点表达的是模数的剩余系中的值。
在 中,我们总共用到了单位复根的这些性质:
- 个单位复根互不相同;
- ;
- ;
- 。
我们发现原根具有和单位复根一样的性质,简单证明 :
令 为大于 的 的幂, 为素数且 , 为 的一个原根。
我们设 :
显然
证毕。
所以将 和 带入本质上和将 和 带入的操作无异。
利用 Vandermonde 矩阵性质,类似 那样,我们可以从 变换得到逆变换 变换,设 为整数序列,则有:
:
:
这里 , 为模意义下的乘法逆元。
当然, 也是有自己的缺点的:比如不能够处理小数的情况,以及不能够处理没有模数的情况。对于模数的选取也有一定的要求,首先是必须要有原根,其次是必须要是 的较高幂次的倍数。
NTT 实现
通过上面的分析,开始写代码吧:-)
也有递归版(Recursion)和迭代版(Iteration) 种实现:
递归版(Recursion)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
| const long long G = 3; const long long G_INV = 332748118; const long long MOD = 998244353;
vector<int> rev;
long long quickPower(long long a, long long b) { long long res = 1;
while (b > 0) { if (b & 1) { res = (res * a) % MOD; }
a = (a * a) % MOD; b >>= 1; }
return res % MOD; }
void ntt(vector<long long> &a, bool invert) { int n = a.size();
if (n == 1) { return; }
vector<long long> Pe(n / 2), Po(n / 2);
for (int i = 0; 2 * i < n; i++) { Pe[i] = a[2 * i]; Po[i] = a[2 * i + 1]; }
ntt(Pe, invert); ntt(Po, invert);
long long wn = quickPower(invert ? G_INV : G, (MOD - 1) / n); long long w = 1;
for (int i = 0; i < n / 2; i++) { a[i] = Pe[i] + w * Po[i] % MOD; a[i] = (a[i] % MOD + MOD) % MOD; a[i + n / 2] = Pe[i] - w * Po[i] % MOD; a[i + n / 2] = (a[i + n / 2] % MOD + MOD) % MOD; w = w * wn % MOD; } }
|
迭代版(Iteration)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
| public: static const long long MOD = 998244353; static const long long G = 3; static const int G_INV = 332748118; vector<int> rev;
long long quickPower(long long a, long long b) { long long res = 1;
while (b > 0) { if (b & 1) { res = (res * a) % MOD; }
a = (a * a) % MOD; b >>= 1; }
return res % MOD; }
void ntt(vector<long long> &a, bool invert = false) { int n = a.size();
for (int i = 0; i < n; i++) { if (i < rev[i]) { swap(a[i], a[rev[i]]); } }
for (int len = 2; len <= n; len <<= 1) { long long wlen = quickPower(invert ? G_INV : G, (MOD - 1) / len);
for (int i = 0; i < n; i += len) { long long w = 1; for (int j = 0; j < len / 2; j++) { long long u = a[i + j]; long long v = (w * a[i + j + len / 2]) % MOD; a[i + j] = (u + v) % MOD; a[i + j + len / 2] = (MOD + u - v) % MOD; w = (w * wlen) % MOD; } } }
if (invert) { long long inver = quickPower(n, MOD - 2); for (int i = 0; i < n; i++) { a[i] = (long long) a[i] * inver % MOD; } } }
|
复杂度分析
参考资料
预览: