任意modでの畳み込み演算をO(n log(n))で
http://misawa.github.io/other/fast_kitamasa_method.html を見て刺激されたので書いた.
- 畳み込み演算?
- 任意modでの畳み込み演算
- 中で使われてる技術達
- c++のコード
コードと原理の説明をちょっとだけ載せています.
コードだけ欲しい人は https://gist.github.com/math314/6a08301b8b75b8172798 をどうぞ. fast_int32mod_convolution を使うとint32の任意のmodが取れます.
実はまだverifyしてない.
畳み込み演算?
畳み込み演算とは,プログラムで表すと
z = [0] * (len(x) + len(y) - 1) for i in xrange(len(x)): for j in xrange(len(y)): z[i+j] += x[i] * y[j]
こんなやつです.例えば x = [2,3] , y = [4,5,6] の時, z = [8, 10 + 12, 12 + 15, 18] = [8, 22, 27, 18] となります.
配列x,yのサイズが両方 O(n) とすると,このプログラム(アルゴリズム)の計算量は O(n2) です.
任意modでの畳み込み演算
以下にコードの一部を載せます.コメントを一杯書いたのでわかりやすければ幸い.
// ガーナーのアルゴリズム garnerのアルゴリズム // x % m[i] == r[i] を満たす,最小の0以上の整数x について, x % mod を求める // ex) // x % 5 == 4 // x % 7 == 1 // x % 11 == 2 // x % 13 = ? // -> x % (5*7*11) = 134 より, 最小のxは 134 // よって, x % 13 = 4 を返す ll garner(vector<Pii> mr, int mod); template<int mod, int primitive_root> class NTT { public: int get_mod() const { return mod; } //畳み込み vector<ll> convolution(const vector<ll>& a, const vector<ll>& b); }; //異なるmodのNTTを定義3つ定義 typedef NTT<167772161, 3> NTT_1; typedef NTT<469762049, 3> NTT_2; typedef NTT<1224736769, 3> NTT_3; // 任意のmodで畳み込み演算 O(n log n) // a.size() + b.size() < 2^23 を仮定している. 8 * 10^6 < 2^23 なので 競技プログラミングでは困らないはず vector<ll> int32mod_convolution(vector<ll> a, vector<ll> b,int mod){ for (auto& x : a) x %= mod; // 最初に -mod < x < mod にしておく for (auto& x : b) x %= mod; // 3種類のmodでNTTを行う // このmodは全て素数で,互いに異なる. NTT_1 ntt1; NTT_2 ntt2; NTT_3 ntt3; auto x = ntt1.convolution(a, b); // x[i] = sum(a[i-j],b[j]) % ntt1.getmod(); auto y = ntt2.convolution(a, b); // y[i] = sum(a[i-j],b[j]) % ntt2.getmod(); auto z = ntt3.convolution(a, b); // z[i] = sum(a[i-j],b[j]) % ntt3.getmod(); //modを取らず畳み込みを行った場合,要素の最大値は 2^22 * (2^31 - 1)^2 となる. // ここで,2^22はmax(a.size(),b.size())の最大値 // (2^31 - 1) は intの最大値,つまりmodの最大値. // 1224736769 * 469762049 * 167772161 > 2^22 * (2^31 - 1)^2 なので // CRT(中国人剰余定理)やgarnerのアルゴリズムで正しい値を復元出来る. // ただし,復元する際に,CRTではlong long超えてしまうため,garnerのアルゴリズムを用いる. vector<ll> ret(sz(x)); vector<Pii> mr(3); FOR(i, sz(x)){ mr[0].first = ntt1.get_mod(), mr[0].second = (int)x[i]; mr[1].first = ntt2.get_mod(), mr[1].second = (int)y[i]; mr[2].first = ntt3.get_mod(), mr[2].second = (int)z[i]; // garnerのアルゴリズムで // t % ntt1.get_mod() = x[i], // t % ntt2.get_mod() = y[i], // t % ntt3.get_mod() = z[i], // を満たす,最小の0以上の整数 t について, // t % mod を求める. ret[i] = garner(mr, mod); } return ret; }
中で使われてる技術達
FFT
FFTは,畳み込み演算を \(O(n log n)\) で行うすごいやつです. ただし,中で浮動小数点数(double)を使って計算するので誤差死が怖い… そこで NTT(FMT)の出番です.
ところで __float128 を使えば実用上誤差が問題にならないのでは?という疑惑がありますがどうなんでしょう.
NTT(FMT)
それぞれ NTT(Number Theoretic Transform) , FMT(Fast Modulo Transformation) です. 後者の方が分かりやすい名前をしている? 両者が同じものを指しているのかは知りませんが,きっと同じ物でしょう.個人的にNTTの方が好きです.
特殊なmodを使うことで,浮動小数点数を使わずにFFTみたいなのが O(n log n) で出来ます.
Garnerのアルゴリズム
x % m[i] = r[i] という式が大量にあるので,これを満たす最小の,0以上の整数xを求めるアルゴリズムです. ただし, m[i],m[j] (i != j) は全て互いに素とします.
mのサイズが O(m) の時,計算量は O(m2) です.
実はこのアルゴリズムを拡張することで,任意のhogeについて x % hoge が計算出来ます.
http://www.csee.umbc.edu/~lomonaco/s08/441/handouts/Garner-Alg-Example.pdf を見ると分かりやすいです.
例題としては http://yukicoder.me/problems/448 (難しめ) でしょうか.
garner単体の実装を載せておきます.
ll garner(vector<Pii> mr, int mod){ mr.emplace_back(mod, 0); vector<ll> coffs(sz(mr), 1); vector<ll> constants(sz(mr), 0); FOR(i, sz(mr) - 1){ // coffs[i] * v + constants[i] == mr[i].second (mod mr[i].first) を解く ll v = (mr[i].second - constants[i]) * mod_inv<ll>(coffs[i], mr[i].first) % mr[i].first; if (v < 0) v += mr[i].first; for (int j = i + 1; j < sz(mr); j++) { (constants[j] += coffs[j] * v) %= mr[j].first; (coffs[j] *= mr[i].first) %= mr[j].first; } } return constants[sz(mr) - 1]; }