Codeforces Global Round 2 H - Triples

https://codeforces.com/contest/1119/problem/H

Preliminaries

The Hadamard transform on 3 bits is almost the same as a Fourier transform of size 2*2*2.
A polynomial in x, y, z (with each variable appearing to degree at most 1) can be represented by the 8 coefficients of 1, x, y, z, xy, xz, yz, xyz. The Hadamard transform is equivalent to evaluating it at 8 points: f(1,1,1), f(1,1,-1), f(1,-1,1), f(1,-1,-1), f(-1,1,1), f(-1,1,-1), f(-1,-1,1), f(-1,-1,-1). We write \(x^{0101}=x_0x_2\), \(x^{1101}=x_0x_2x_3\), and we define \(x^{0101} x^{1101} = x^{0101 \oplus 1101}\); in other words, \(x_i^2=1\). With this property imposed on the indeterminates, we are working in the polynomial ring \(R[x_0,\ldots,x_{K - 1}]\) taken modulo the relations \(x_i^2 = 1\).

Solution

Inputs are denoted by capital letters. Using polynomials, this problem can be restated as finding the coefficient of each monomial of

\[ \prod_{i} (X x^{A[i]} + Y x^{B[i]} + Z x^{C[i]}) \]

Factoring \(x^{A[i]}\) out of every factor (using \(x^{a} x^{b} = x^{a \oplus b}\)), the product becomes

\[ x^{\bigoplus_i A[i]} \prod_i (X + Y x^{A[i] \oplus B[i]} + Z x^{A[i] \oplus C[i]})\]

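From now on, let \(B[i] \gets A[i] \oplus B[i]\), \(C[i] \gets A[i] \oplus C[i]\). The problem then reduces to expanding the product below; the factored-out monomial \(x^{\bigoplus_i A[i]}\) only XOR-shifts every exponent of the result at the very end.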

\[ \prod_i (X + Y x^{B[i]} + Z x^{C[i]})\]

First, we apply the Hadamard transform to the above polynomial, i.e. we evaluate it at every assignment \(x_j \in \{1, -1\}\). For any such assignment, each factor takes one of only four values: X+Y+Z, X-Y+Z, X+Y-Z, X-Y-Z. Fix an assignment of \(x\), and let \(\alpha, \beta, \gamma, \delta\) be the numbers of factors equal to X+Y+Z, X-Y+Z, X+Y-Z, X-Y-Z, respectively. Then the following holds.

\begin{align}
\alpha + \beta + \gamma + \delta &= N,\\
\alpha - \beta + \gamma - \delta &=\sum_i x^{B[i]},\\
\alpha + \beta - \gamma - \delta &=\sum_i x^{C[i]},\\
\alpha - \beta - \gamma + \delta &=\sum_i x^{B[i] \oplus C[i]}.
\end{align}

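Thus we can obtain \(\alpha, \beta, \gamma, \delta\) by solving these equations. The three right-hand sides are, for all assignments at once, the Hadamard transforms of the count arrays of \(B[i]\), \(C[i]\) and \(B[i] \oplus C[i]\). The value of the product at the fixed assignment is then \((X+Y+Z)^{\alpha}(X-Y+Z)^{\beta}(X+Y-Z)^{\gamma}(X-Y-Z)^{\delta}\), and applying the inverse Hadamard transform to these values recovers the coefficients.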

#include <stdio.h>
#include <vector>

using namespace std;

// Modulus for all mint arithmetic. modinv() raises to the power MOD - 2,
// which relies on this value being prime.
constexpr int MOD = 998244353;

// Thin wrapper around an int that the free operators below treat as a
// residue modulo MOD. The constructor stores its argument verbatim (it does
// not reduce it) and is intentionally implicit so that plain ints convert.
struct mint {
    int n;  // stored value; defaults to 0

    mint(int n_ = 0) { n = n_; }
};

// Modular addition; both operands are expected to already lie in [0, MOD),
// so a single conditional subtraction brings the sum back into range.
mint operator+(mint a, mint b) {
  int sum = a.n + b.n;
  if (sum >= MOD) sum -= MOD;
  return mint(sum);
}
// Modular subtraction; both operands are expected to already lie in
// [0, MOD), so a single conditional addition fixes a negative difference.
mint operator-(mint a, mint b) {
  int diff = a.n - b.n;
  if (diff < 0) diff += MOD;
  return mint(diff);
}
// Modular multiplication; the product is formed in 64 bits before being
// reduced modulo MOD.
mint operator*(mint a, mint b) {
  const long long product = static_cast<long long>(a.n) * b.n;
  return mint(static_cast<int>(product % MOD));
}
// In-place modular addition; returns a reference to the updated left operand.
mint &operator+=(mint &a, mint b) {
  a = a + b;
  return a;
}
// In-place modular subtraction; returns a reference to the updated left operand.
mint &operator-=(mint &a, mint b) {
  a = a - b;
  return a;
}
// In-place modular multiplication; returns a reference to the updated left operand.
mint &operator*=(mint &a, mint b) {
  a = a * b;
  return a;
}

// Raises a to the power b by binary exponentiation, using mint
// multiplication. An exponent that is zero or negative yields 1, because the
// loop body never runs.
mint modpow(mint a, long long b) {
  mint acc = 1;
  for (; b > 0; b >>= 1) {
    // Multiply in the current power of a whenever the lowest bit is set.
    if ((b & 1) != 0) acc *= a;
    a *= a;
  }
  return acc;
}

// Multiplicative inverse of a, computed as a^(MOD - 2) (Fermat's little
// theorem, so MOD must be prime). For a == 0 this returns 0.
mint modinv(mint a) {
  const long long exponent = MOD - 2;
  return modpow(a, exponent);
}

// Modular division: multiplies a by the modular inverse of b.
mint operator/(mint a, mint b) {
  const mint inverse = modinv(b);
  return a * inverse;
}

// In-place, unnormalized Walsh-Hadamard transform of the half-open range
// a[l..r).
//
// The range is split in half, each half is transformed recursively, and then
// every element i of the lower half is combined with its partner i + n/2 as
// (x + y, x - y). No division is performed, so the caller is responsible for
// any normalization.
//
// T only needs binary + and -. The pairing assumes the length r - l is a
// power of two. Ranges with fewer than two elements (including empty ranges)
// are left untouched.
template<class T>
void H(T *a, int l, int r) {
  int n = r - l;
  // "<= 1" rather than "== 1": an empty range would otherwise never reach
  // the base case and recurse forever.
  if (n <= 1) return;
  // l + n / 2 instead of (l + r) / 2: cannot overflow and matches the
  // i + n / 2 partner index used below.
  int m = l + n / 2;
  H(a, l, m);
  H(a, m, r);
  for (int i = l; i < m; i++) {
    int j = i + n / 2;
    T x = a[i];
    T y = a[j];
    a[i] = x + y;
    a[j] = x - y;
  }
}

// Reads one decimal integer from stdin and returns it as a residue in
// [0, MOD). If no integer can be parsed (EOF or malformed input), 0 is
// returned instead of reading an uninitialized variable.
mint in() {
  int x = 0;
  if (scanf("%d", &x) != 1) x = 0;
  // In C++ the sign of % follows the dividend, so a negative input would
  // otherwise leave a negative residue and break the [0, MOD) invariant that
  // the mint operators rely on.
  int r = x % MOD;
  if (r < 0) r += MOD;
  return mint(r);
}

int main() {
  int N, K;
  scanf("%d %d", &N, &K);
  vector<int> A(1 << K);
  vector<int> B(1 << K);
  vector<int> C(1 << K);
  mint x = in(), y = in(), z = in();
  int s = 0;
  for (int i = 0; i < N; i++) {
    int a, b, c;
    scanf("%d %d %d", &a, &b, &c);
    s ^= a;
    b ^= a;
    c ^= a;
    A[b]++;
    B[c]++;
    C[b^c]++;
  }
  H(A.data(), 0, 1<<K);
  H(B.data(), 0, 1<<K);
  H(C.data(), 0, 1<<K);
  vector<mint> ans(1 << K, 1);
  for (int i = 0; i < 1 << K; i++) {
    // x+y+z -> a
    // x-y+z -> b
    // x+y-z -> c
    // x-y-z -> d
    vector<int> D(4);
    // D[0] = a + b + c + d
    // D[1] = a - b + c - d
    // D[2] = a + b - c - d
    // D[3] = a - b - c + d
    D[0] = N;
    D[1] = A[i];
    D[2] = B[i];
    D[3] = C[i];
    H(D.data(), 0, 4);
    mint v = 1;
    v *= modpow(x+y+z, D[0] / 4);
    v *= modpow(x-y+z, D[1] / 4);
    v *= modpow(x+y-z, D[2] / 4);
    v *= modpow(x-y-z, D[3] / 4);
    ans[i] = v;
  }
  H(ans.data(), 0, 1 << K);
  vector<mint> out(1 << K);
  for (int i = 0; i < 1 << K; i++) {
    out[i ^ s] = ans[i] / (1 << K);
  }
  for (int i = 0; i < 1 << K; i++) {
    printf("%d ", out[i].n);
  }
  putchar('\n');
}