Range tree (Fenwick tree)

領域木の実装が楽になる小ネタ。決して速いわけではない。

struct RangeTree {
  std::vector<std::vector<int>> dat;

  RangeTree(const std::vector<int> &a) : dat(a.size() + 1) {
    const int n = a.size();
    for (int i = 0; i < n; i++) {
      for (int j = i + 1; j <= n; j += j & -j) dat[j].push_back(a[i]);
      std::sort(dat[i + 1].begin(), dat[i + 1].end());
    }
  }

  int count(int l, int r, int v) {
    int ret = 0;
    for (int k = r; k > 0; k &= k - 1) ret += std::lower_bound(dat[k].begin(), dat[k].end(), v) - dat[k].begin();
    for (int k = l; k > 0; k &= k - 1) ret -= std::lower_bound(dat[k].begin(), dat[k].end(), v) - dat[k].begin();
    return ret;
  }
};

(修正)segtree の方も短くなりそうだったから短くした。一応注意として r=N だとバグる。何かあまり変わらなくて、この記事の存在理由が怪しくなってきた。segtree における [0,k) 型クエリの実装が楽だということも小ネタにしておく…。

const int N = 1 << 18;

struct RangeTreeST {
  std::vector<std::vector<int>> dat;

  RangeTreeST(const std::vector<int> &a) {
    dat.resize(N * 2);
    for (int i = 0; i < a.size(); i++) {
      for (int j = i + N; j >= 1; j >>= 1) dat[j].push_back(a[i]);
    }
    for (int i = 0; i < N * 2; i++) {
      std::sort(dat[i].begin(), dat[i].end());
    }
  }

  int count(int l, int r, int v) {
    int ret = 0;
    for (int k = r + N; k > 1; k >>= 1) if (k & 1) --k, ret += std::lower_bound(dat[k].begin(), dat[k].end(), v) - dat[k].begin();
    for (int k = l + N; k > 1; k >>= 1) if (k & 1) --k, ret -= std::lower_bound(dat[k].begin(), dat[k].end(), v) - dat[k].begin();
    return ret;
  }
};

Easy Queries [wavelet matrix 3d]

https://www.codechef.com/problems/DISTNUM2

参考にしました。

http://min-25.hatenablog.com/entry/2017/09/13/073449

0.79 秒、24 MB。

要素が少ない方の wavelet matrix だけ使用するヒューリスティックは組み込んだ (0.4 sec の改善)。構築を手抜きしてるといいつつ、メモリはきっちり確保するような実装になっているので、メモリ節約にはなってない。

Min_25 さんのコードに比べて何も良くなってないが、とりあえずメモしておく。改善できたら更新する。

#include <iostream>
#include <vector>
#include <algorithm>
#include <cstdio>

using namespace std;

template<int NN> struct bitvec {
	static constexpr int N = (NN + 63) / 64 * 64;
	struct bar {
		unsigned long long mask;
		int sum;
	} dat[N / 64 + 1];
	void set(int k, bool v) { if (v) dat[k / 64].mask |= 1ULL << k % 64; }
	void build() { for (int i = 0; i < N / 64; i++) dat[i + 1].sum = dat[i].sum + __builtin_popcountll(dat[i].mask); }
	int rank(int k) const { return dat[k / 64].sum + __builtin_popcountll(dat[k / 64].mask & (1ULL << k % 64) - 1); }
	int rank(int l, int r) const { return rank(r) - rank(l); }
};

struct foo { int y, z; };

constexpr int N = 1e5;
constexpr int H = 17;

int n;
foo aa[N], bb[N];
bitvec<N> zz[H], yy[H + 1][H];

void build_y(int dd, int l, int r, int d) {
	if (d == H || l == r) return;
	for (int i = l; i < r; i++) yy[dd][d].set(i, ~bb[i].y >> H - 1 - d & 1);
	int m = stable_partition(bb + l, bb + r, [&](foo v) { return ~v.y >> H - 1 - d & 1; }) - bb;
	build_y(dd, l, m, d + 1);
	build_y(dd, m, r, d + 1);
}

void build_z(int l, int r, int d) {
	if (d == H || l == r) return;
	for (int i = l; i < r; i++) zz[d].set(i, ~aa[i].z >> H - 1 - d & 1);
	int m = stable_partition(aa + l, aa + r, [&](foo v) { return ~v.z >> H - 1 - d & 1; }) - aa;
	if (m - l < r - m) {
		copy(aa + l, aa + m, bb + l);
		build_y(d, l, m, 0);
	} else {
		copy(aa + m, aa + r, bb + m);
		build_y(d, m, r, 0);
	}
	build_z(l, m, d + 1);
	build_z(m, r, d + 1);
}

void build() {
	copy(aa, aa + n, bb);
	build_y(H, 0, n, 0);
	build_z(0, n, 0);
	for (int i = 0; i <= H; i++) {
		if (i < H) zz[i].build();
		for (int j = 0; j < H; j++) yy[i][j].build();
	}
}

int count(int dd, int ll, int rr, int l, int r, int k) {
	int ret = 0;
	for (int i = 0; i < H; i++) {
		if (l == r) return ret;
		int mm = yy[dd][i].rank(ll, rr) + ll;
		if (~k >> H - 1 - i & 1) {
			l = yy[dd][i].rank(ll, l) + ll;
			r = yy[dd][i].rank(ll, r) + ll;
			rr = mm;
		} else {
			ret += yy[dd][i].rank(l, r);
			l += yy[dd][i].rank(l, rr);
			r += yy[dd][i].rank(r, rr);
			ll = mm;
		}
	}
	return ret;
}

int kth_element(int l, int r, int y, int k) {
	int ll = 0, rr = n, t = count(H, 0, n, l, r, y);
	for (int i = 0; i < H; i++) {
		int mm = zz[i].rank(ll, rr) + ll;
		int l0 = zz[i].rank(ll, l) + ll, l1 = zz[i].rank(l, rr) + l;
		int r0 = zz[i].rank(ll, r) + ll, r1 = zz[i].rank(r, rr) + r;
		int cnt = mm - ll < rr - mm ? count(i, ll, mm, l0, r0, y) : t - count(i, mm, rr, l1, r1, y);
		if (k < cnt) {
			l = l0; r = r0; rr = mm; t = cnt;
		} else {
			l = l1; r = r1; ll = mm; k -= cnt; t -= cnt;
		}
	}
	return l < r ? aa[l].z : -1;
}

int main() {
	int n, q;
	cin >> n >> q;

	vector<int> a(n);
	for (int i = 0; i < n; i++) {
		scanf("%d", &a[i]);
	}
	vector<int> dic(a);
	sort(dic.begin(), dic.end());
	dic.erase(unique(dic.begin(), dic.end()), dic.end());
	for (int i = 0; i < n; i++) {
		a[i] = lower_bound(dic.begin(), dic.end(), a[i]) - dic.begin();
	}

	vector<int> pos(n, -1);
	for (int i = 0; i < n; i++) {
		aa[i].y = pos[a[i]] + 1;
		aa[i].z = a[i];
		pos[a[i]] = i;
	}
	::n = n;
	build();

	int ans = 0;
	for (int ii = 0; ii < q; ii++) {
		int a, b, c, d, k;
		scanf("%d %d %d %d %d", &a, &b, &c, &d, &k);
		int l = (1LL * a * max(0, ans) + b) % n + 1;
		int r = (1LL * c * max(0, ans) + d) % n + 1;
		int tmp = kth_element(l - 1, r, l, k - 1);
		if (tmp != -1) tmp = dic[tmp];
		ans = tmp;
		printf("%d\n", ans);
	}
}

wavelet matrix を初めて書いた問題なので思い入れが強い。

ACPC2017Day1 F: Steps

http://judge.u-aizu.ac.jp/onlinejudge/cdescription.jsp?cid=ACPC2017Day1&pid=F

dyck 列(正しい括弧列)の数え上げ問題に、深さ上限を与えた問題。dyck 列の総数はカタラン数として知られる。解法の本質はカタラン数の二項係数表現の導出と同じで、エラーが起きた地点から移動方向を反転させるとよい。カタラン数については分かりやすいサイトが見つからなかったため参考リンクは用意していないが、『数学ガール』の 1 巻に書いてあった説明が分かりやすかった覚えがある(読んだの高 1 くらいなので、具体的に何が書かれてたか覚えてないけど)。定番ネタなので探せばいくらでもあるだろう。

たとえば左のような移動経路は、右の移動経路に対応させる。E (エラー地点)から移動方向を反転させている。

        *                 
 *     * G       *        
S-*---*---  <=> S-*-*------     
   E *             E *    
    *                 *   
                       * G'
                        * 

S→Gへの経路のうち、下側エラーが起きるパターン数は、S から G' への移動経路数と一致することが言える。ここまではカタラン数の計算方法として有名だろう。

今回は下側だけでなく上側も制約がつく。上側エラーも同様に計算できるが、下側と上側で二重カウントしてしまう。そのため包除原理的に重複を除去していく。

下側エラー→上側エラーとなるパターン数と上側エラー→下側エラーとなるパターン数もまた、先程と同様のテクニックにより計算できる(二回反転させる)。下側エラー→上側エラー→下側エラーというのも同様に計算できる(三階反転させる)。これらが計算できれば、包除原理的に解ける。下→上→下→上→…というのを永遠に求め続ける必要はなく、m の制約からして有限個である(m/n個程度であり、これは計算量に影響を与える)

計算量はわかりづらいが、O(n+m) になる。競技中には計算量解析が雑で、n≤500 のとき別の解法を取るような O( (n+m) sqrt(n) ) の解法を設計していたが、その必要はなかった。

#include <iostream>
#include <algorithm>
#include <vector>
#include <cstring>

using namespace std;

constexpr int mod = 1e9 + 7;

struct modint {
  int n;
  modint(int n = 0) : n(n) {}
};

modint operator+(modint a, modint b) { return modint((a.n += b.n) >= mod ? a.n - mod : a.n); }
modint operator-(modint a, modint b) { return modint((a.n -= b.n) < 0 ? a.n + mod : a.n); }
modint operator*(modint a, modint b) { return modint(1LL * a.n * b.n % mod); }
modint &operator+=(modint &a, modint b) { return a = a + b; }
modint &operator-=(modint &a, modint b) { return a = a - b; }
modint &operator*=(modint &a, modint b) { return a = a * b; }

modint fact[505050];
modint ifact[505050];
modint inv[505050];

modint C(int n, int r) {
  if (n < 0 || r < 0 || n < r) return 0;
  return fact[n] * ifact[r] * ifact[n - r];
}

modint g(int w, int y) {
  y = abs(y);
  if (w < y) return 0;
  if ((w - y) % 2 != 0) return 0;
  return C(w, (w - y) / 2);
}

int mirror(int y, int u) {
  return -(y - u) + u;
}

modint f(int w, int h, int y) {
  modint ret;
  ret += g(w, y);
  int y1 = y;
  int y2 = y;
  int d1 = -1;
  int d2 = h + 1;
  for (int ii = 0; min(-d1, d2) <= w; ii++) {
    y1 = mirror(y1, d1);
    y2 = mirror(y2, d2);
    if (ii & 1) {
      ret += g(w, y1);
      ret += g(w, y2);
    } else {
      ret -= g(w, y1);
      ret -= g(w, y2);
    }
    d1 -= h + 2;
    d2 += h + 2;
  }
  return ret;
}

int main() {
  fact[0] = 1;
  ifact[0] = 1;
  inv[1] = 1;
  for (int i = 2; i < 505050; i++) {
    inv[i] = inv[mod % i] * (mod - mod / i);
  }
  for (int i = 1; i < 505050; i++) {
    fact[i] = i * fact[i - 1];
    ifact[i] = inv[i] * ifact[i - 1];
  }
  int n, m;
  cin >> n >> m;

  modint ans = 0;
  for (int i = 0; i <= n - 1; i++) {
    ans += f(m, n - 1, i);
  }
  for (int i = 0; i <= n - 2; i++) {
    ans -= f(m, n - 2, i);
  }
  cout << ans.n << endl;
}

そういえば通り数って聞き馴染みがない(twitter 上でしか見ない)んだけど、よく使われる言葉なのだろうか?場合の数という表現は見覚えがある。