読者です 読者をやめる 読者になる 読者になる

pekempeyのブログ

競技プログラミングに関する話題を書いていきます。

Mo's algorithm

Educational Codeforces Round 6 の F 問題 Xors on Segments の想定解が Mo's algorithm らしいのですが、Mo's algorithm を知らなかったので調べました。 そのついでに記事にしておきます。

このページを参考にしました。

www.hackerearth.com

Mo's algorithm とは

区間クエリ系の問題を解くためのアルゴリズム。次のようなクエリに対して有効。

  • 要素が更新されない
  • クエリの先読みが可能
  • 区間 [L,R] の結果から [L-1, R], [L+1, R], [L, R-1], [L, R+1] の結果が容易に得られる

Mo's algorithm の流れ

まず区間を平方分割し、左端をキーにしてクエリをバケットに入れる。その後、各バケット内で右端をキーにしてソートする。

結局のところ上の操作は次の比較関数でソートすることと同じ。

bool comp(a, b):
    S = sqrt(N)
    if a.L / S != b.L / S:
        return a.L / S < b.L / S
    return a.R < b.R

前処理はこれで終わり。ここからどのようにクエリを処理するのか。

実はしゃくとりっぽく区間の伸ばしたり縮めたりして、現在のクエリ区間から次のクエリ区間まで変化させるだけで良い。

要するに次のような処理を行う。

queries を comp でソート

i, j = 0, -1
for l, r, index in queries:
    while i < l: remove(i++)  // [i, j] -> [i+1, j]
    while i > l: add(--i)     // [i, j] -> [i-1, j]
    while j < r: add(++j)     // [i, j] -> [i, j+1]
    while j > r: remove(j--)  // [i, j] -> [i, j-1]
    ans[index] = answer

これだけで計算量が O((N+Q)√N) になる。

Mo's algorithm の計算量

左端 i と右端 j の振る舞いに着目する。

左端 i の振る舞い

基本的にはバケット内で移動するが、O(√N) 回バケットをまたいだ移動をする。

f:id:pekempey:20160123152707p:plain:w400

バケット内で移動する場合、一回のクエリあたり距離 O(√N) の移動するのでクエリ全体では O(Q√N)。 バケットをまたいだ移動は全体で O(N)。 したがって総移動回数は O(Q√N + N) 回。

右端 j の振る舞い

左端が同じバケットにある限りひたすら右に移動し、 左端のバケットが切り替わるタイミングで距離 O(N) の移動を行う。

f:id:pekempey:20160123153929p:plain:w400

したがって総移動回数は O(N√N) 回。


以上より O((N+Q)√N) であることが分かる。

例題

例題として D-query という問題を取り上げる。これは次のような問題。

  • 整数 a1, a2, ... , an が与えられる。
  • 次のクエリを処理する。
    • 区間 [l, r] に含まれる整数の種類数を出力する。

先ほどの擬似コードの通りに書けば良い。

#include <bits/stdc++.h>
using namespace std;
#define rep(i, a) for (int i = 0; i < (a); i++)
#define rep2(i, a, b) for (int i = (a); i < (b); i++)
#define repr(i, a) for (int i = (b) - 1; i >= 0; i--)
#define repr2(i, a, b) for (int i = (b) - 1; i >= (a); i--)
template<class T1, class T2> bool chmin(T1 &a, T2 b) { return b < a && (a = b, true); }
template<class T1, class T2> bool chmax(T1 &a, T2 b) { return a < b && (a = b, true); }
typedef long long ll;

const int N = 400;

bool comp(tuple<int, int, int> a, tuple<int, int, int> b) {
    if (get<0>(a) / N != get<0>(b) / N) return get<0>(a) / N < get<0>(b) / N;
    return get<1>(a) < get<1>(b);
}

int a[30303];
int num[1000100];
int cnt;

void add(int k) {
    num[a[k]]++;
    if (num[a[k]] == 1) cnt++;
}

void remove(int k) {
    num[a[k]]--;
    if (num[a[k]] == 0) cnt--;
}

int main() {
    int n;
    cin >> n;
    rep(i, n) scanf("%d", &a[i]);

    int q;
    cin >> q;
    vector<tuple<int, int, int>> qs(q);
    rep(i, q) {
        int l, r;
        scanf("%d %d", &l, &r);
        qs[i] = make_tuple(l - 1, r - 1, i);
    }
    sort(qs.begin(), qs.end(), comp);

    vector<int> ans(q);

    int i = 0, j = -1;
    for (auto t : qs) {
        int l, r, index;
        tie(l, r, index) = t;
        while (i < l) remove(i++);
        while (i > l) add(--i);
        while (j < r) add(++j);
        while (j > r) remove(j--);

        ans[index] = cnt;
    }

    rep(i, q) printf("%d\n", ans[i]);

    return 0;
}

練習問題

この実装だと i が j を追い越すので、それが問題になる場合は書き換えてください。