math314のブログ

主に競技プログラミング,CTFの結果を載せます

任意modでの畳み込み演算をO(n log(n))で

http://misawa.github.io/other/fast_kitamasa_method.html を見て刺激されたので書いた.

  1. 畳み込み演算?
  2. 任意modでの畳み込み演算
  3. 中で使われてる技術達
  4. c++のコード

コードと原理の説明をちょっとだけ載せています.

コードだけ欲しい人は https://gist.github.com/math314/6a08301b8b75b8172798 をどうぞ. fast_int32mod_convolution を使うとint32の任意のmodが取れます.

実はまだverifyしてない.

畳み込み演算?

畳み込み演算とは,プログラムで表すと

z = [0] * (len(x) + len(y) - 1)
for i in xrange(len(x)):
    for j in xrange(len(y)):
        z[i+j] += x[i] * y[j]

こんなやつです.例えば x = [2,3] , y = [4,5,6] の時, z = [8, 10 + 12, 12 + 15, 18] = [8, 22, 27, 18] となります.

配列x,yのサイズが両方 O(n) とすると,このプログラム(アルゴリズム)の計算量は O(n2) です.

任意modでの畳み込み演算

以下にコードの一部を載せます.コメントを一杯書いたのでわかりやすければ幸い.

// ガーナーのアルゴリズム garnerのアルゴリズム
// x % m[i] == r[i] を満たす,最小の0以上の整数x について, x % mod を求める
// ex)
// x % 5  == 4
// x % 7  == 1
// x % 11 == 2
// x % 13 = ?
// -> x % (5*7*11) = 134 より, 最小のxは 134
// よって, x % 13 = 4 を返す
ll garner(vector<Pii> mr, int mod);

template<int mod, int primitive_root>
class NTT {
public:
    int get_mod() const { return mod; }
    //畳み込み
    vector<ll> convolution(const vector<ll>& a, const vector<ll>& b);
};

//異なるmodのNTTを定義3つ定義
typedef NTT<167772161, 3> NTT_1;
typedef NTT<469762049, 3> NTT_2;
typedef NTT<1224736769, 3> NTT_3;

// 任意のmodで畳み込み演算 O(n log n)
// a.size() + b.size() < 2^23 を仮定している. 8 * 10^6 < 2^23 なので 競技プログラミングでは困らないはず
vector<ll> int32mod_convolution(vector<ll> a, vector<ll> b,int mod){
    for (auto& x : a) x %= mod; // 最初に -mod < x < mod にしておく
    for (auto& x : b) x %= mod;

    // 3種類のmodでNTTを行う
    // このmodは全て素数で,互いに異なる.
    NTT_1 ntt1; NTT_2 ntt2; NTT_3 ntt3;
    auto x = ntt1.convolution(a, b); // x[i] = sum(a[i-j],b[j]) % ntt1.getmod();
    auto y = ntt2.convolution(a, b); // y[i] = sum(a[i-j],b[j]) % ntt2.getmod();
    auto z = ntt3.convolution(a, b); // z[i] = sum(a[i-j],b[j]) % ntt3.getmod();

    //modを取らず畳み込みを行った場合,要素の最大値は 2^22 * (2^31 - 1)^2 となる.
    // ここで,2^22はmax(a.size(),b.size())の最大値
    // (2^31 - 1) は intの最大値,つまりmodの最大値.
    // 1224736769 * 469762049 * 167772161 > 2^22 * (2^31 - 1)^2 なので
    // CRT(中国人剰余定理)やgarnerのアルゴリズムで正しい値を復元出来る.
    // ただし,復元する際に,CRTではlong long超えてしまうため,garnerのアルゴリズムを用いる.
    vector<ll> ret(sz(x));
    vector<Pii> mr(3);
    FOR(i, sz(x)){
        mr[0].first = ntt1.get_mod(), mr[0].second = (int)x[i];
        mr[1].first = ntt2.get_mod(), mr[1].second = (int)y[i];
        mr[2].first = ntt3.get_mod(), mr[2].second = (int)z[i];
        // garnerのアルゴリズムで
        // t % ntt1.get_mod() = x[i],
        // t % ntt2.get_mod() = y[i],
        // t % ntt3.get_mod() = z[i],
        // を満たす,最小の0以上の整数 t について,
        // t % mod を求める.
        ret[i] = garner(mr, mod);
    }

    return ret;
}

中で使われてる技術達

  1. FFT
  2. NTT(FMT)
  3. Garnerのアルゴリズム

FFT

FFTは,畳み込み演算を \(O(n log n)\) で行うすごいやつです. ただし,中で浮動小数点数(double)を使って計算するので誤差死が怖い… そこで NTT(FMT)の出番です.

ところで __float128 を使えば実用上誤差が問題にならないのでは?という疑惑がありますがどうなんでしょう.

NTT(FMT)

それぞれ NTT(Number Theoretic Transform) , FMT(Fast Modulo Transformation) です. 後者の方が分かりやすい名前をしている? 両者が同じものを指しているのかは知りませんが,きっと同じ物でしょう.個人的にNTTの方が好きです.

特殊なmodを使うことで,浮動小数点数を使わずにFFTみたいなのが O(n log n) で出来ます.

Garnerのアルゴリズム

x % m[i] = r[i] という式が大量にあるので,これを満たす最小の,0以上の整数xを求めるアルゴリズムです. ただし, m[i],m[j] (i != j) は全て互いに素とします.

mのサイズが O(m) の時,計算量は O(m2) です.

実はこのアルゴリズムを拡張することで,任意のhogeについて x % hoge が計算出来ます.

http://www.csee.umbc.edu/~lomonaco/s08/441/handouts/Garner-Alg-Example.pdf を見ると分かりやすいです.

例題としては http://yukicoder.me/problems/448 (難しめ) でしょうか.

garner単体の実装を載せておきます.

ll garner(vector<Pii> mr, int mod){
    mr.emplace_back(mod, 0);

    vector<ll> coffs(sz(mr), 1);
    vector<ll> constants(sz(mr), 0);
    FOR(i, sz(mr) - 1){
        // coffs[i] * v + constants[i] == mr[i].second (mod mr[i].first) を解く
        ll v = (mr[i].second - constants[i]) * mod_inv<ll>(coffs[i], mr[i].first) % mr[i].first;
        if (v < 0) v += mr[i].first;

        for (int j = i + 1; j < sz(mr); j++) {
            (constants[j] += coffs[j] * v) %= mr[j].first;
            (coffs[j] *= mr[i].first) %= mr[j].first;
        }
    }

    return constants[sz(mr) - 1];
}

コード