acm模板-多项式

发布于 2022-04-14  1883 次阅读


多项式

多项式类

// Modular addition; both operands are expected to be reduced into [0, MOD),
// so a single conditional subtraction suffices.
inline int add(int x, int y) {
  int s = x + y;
  if (s >= MOD) s -= MOD;
  return s;
}

// Modular subtraction x - y with the same reduced-operand expectation.
inline int dec(int x, int y) {
  int d = x - y;
  if (d < 0) d += MOD;
  return d;
}

// In-place x += y (mod MOD).
inline void inc(int &x, int y) {
  x += y;
  if (x >= MOD) x -= MOD;
}

// In-place x -= y (mod MOD).
inline void rec(int &x, int y) {
  x -= y;
  if (x < 0) x += MOD;
}

// NTT support tables:
//   Wn  - twiddle factors, grouped by butterfly half-length (see init_poly);
//   lg  - lg[i] = floor(log2(i)) for i >= 2 (lg[0] = lg[1] = 0);
//   r   - bit-reversal permutation built by init(lim);
//   tot - number of entries currently written to Wn.
int Wn[N << 1], lg[N], r[N], tot;

// Fills lg[] and Wn[] for transform lengths up to p, the smallest power of
// two strictly greater than n.
// For every half-length i = 1, 2, 4, ..., p / 2, the powers w^0 .. w^(i-1) of
// w = 3^((MOD - 1) / (2i)) are stored contiguously at Wn[i .. 2i - 1];
// NTT reads them as Wn[mid + j].
// tot is reset first, so calling this more than once (e.g. with a larger n)
// rebuilds the table at the correct offsets instead of appending after the
// previous contents, which would shift every block away from Wn[i].
inline void init_poly(int n) {
  int p = 1;
  while (p <= n) p <<= 1;
  for (int i = 2; i <= p; ++i) lg[i] = lg[i >> 1] + 1;
  tot = 0;  // block for half-length i must start exactly at Wn[i]
  for (int i = 1; i < p; i <<= 1) {
    int wn = ksm(3, (MOD - 1) / (i << 1));
    Wn[++tot] = 1;
    for (int j = 1; j < i; ++j) ++tot, Wn[tot] = 1ll * Wn[tot - 1] * wn % MOD;
  }
}

// Builds the bit-reversal permutation r[0 .. lim) for a transform of length
// lim (a power of two whose lg[] entry was filled by init_poly).
// lim <= 1 is handled separately: the permutation is just r[0] = 0, and the
// general formula would use lg[1] - 1 == -1 as a shift count, which is
// undefined behavior (poly::sqrt calls init(1)).
inline void init(int lim) {
  if (lim <= 1) {
    r[0] = 0;
    return;
  }
  int len = lg[lim] - 1;
  for (int i = 0; i < lim; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << len);
}

// Inverse table: iv[i] = i^{-1} mod MOD for 1 <= i < tp (iv[0] = 1 is a
// placeholder). tp is the first index not yet filled; 0 means "empty".
int iv[N], tp;

// Extends the inverse table so that iv[0 .. n] are available, using
// i^{-1} = (MOD - MOD / i) * (MOD mod i)^{-1} (mod MOD).
// Entries that are already filled are not recomputed.
inline void init_inv(int n) {
  if (tp == 0) {
    iv[0] = 1;
    iv[1] = 1;
    tp = 2;
  }
  while (tp <= n) {
    iv[tp] = 1ll * (MOD - MOD / tp) * iv[MOD % tp] % MOD;
    ++tp;
  }
}

// Cipolla state: I is the value a^2 - n chosen by cipolla() (one for which
// check() fails); elements below live in Z_MOD[w] with w^2 = I.
int I;
mt19937 rnd(time(nullptr));

// Element a + b*w of the extension used by Cipolla's algorithm.
struct pt {
  int a, b;

  explicit pt(int _a = 0, int _b = 0) : a(_a), b(_b) {}
};

// (x.a + x.b w)(y.a + y.b w) = (x.a y.a + x.b y.b I) + (x.a y.b + x.b y.a) w.
inline pt operator*(pt x, pt y) {
  int re = add(1ll * x.a * y.a % MOD, 1ll * x.b * y.b % MOD * I % MOD);
  int im = add(1ll * x.a * y.b % MOD, 1ll * x.b * y.a % MOD);
  return pt(re, im);
}

// Euler's criterion: true iff x^((MOD - 1) / 2) == 1, i.e. x is a nonzero
// quadratic residue when MOD is an odd prime.
inline bool check(int x) {
  const auto euler = ksm(x, (MOD - 1) / 2);
  return euler == 1;
}

// Random value in [0, MOD) drawn from rnd.
// NOTE(review): glibc's <stdlib.h> already declares `long random(void)`;
// this definition may clash with it on such platforms — consider renaming.
inline long random() {
  const auto raw = rnd();
  return raw % MOD;
}

// Binary exponentiation a^b in the Cipolla extension.
inline pt qpow(pt a, int b) {
  pt acc(1, 0);
  while (b) {
    if (b & 1) acc = acc * a;
    a = a * a;
    b >>= 1;
  }
  return acc;
}

// Square root of n modulo MOD by Cipolla's algorithm.
// Returns 0 when check(n) is false (e.g. n is a quadratic non-residue);
// otherwise returns the smaller of the two roots.
// Side effect: overwrites the global I used by pt multiplication.
inline int cipolla(int n) {
  if (!check(n)) return 0;
  int a;
  do {
    a = random();
  } while (!a || check(dec(1ll * a * a % MOD, n)));
  I = dec(1ll * a * a % MOD, n);
  const int root = qpow(pt(a, 1), (MOD + 1) / 2).a;
  const int other = static_cast<int>(MOD) - root;
  return root < other ? root : other;
}

// Dense polynomial over Z_MOD: v[i] is the coefficient of x^i.
// A default-constructed poly is the constant 0 with one stored coefficient.
struct poly {
  vector<int> v;

  // Constant polynomial w.
  inline poly(int w = 0) : v(1) { v[0] = w; }

  // Takes ownership of an explicit coefficient vector.
  inline poly(vector<int> w) : v(move(w)) {}

  // Read-only access: indices past the end read as 0.
  inline int operator[](int x) const { return x >= v.size() ? 0 : v[x]; }

  // Writable access: grows v (zero-filled) so that index x exists.
  inline int &operator[](int x) {
    if (x >= v.size()) v.resize(x + 1);
    return v[x];
  }

  // Number of stored coefficients. const so it also works on const polys.
  inline int size() const { return v.size(); }

  inline void resize(int x) { v.resize(x); }

  // The first len coefficients, zero-padded when len exceeds the size.
  inline poly slice(int len) const {
    if (len <= v.size()) return vector<int>(v.begin(), v.begin() + len);
    vector<int> ret(v);
    ret.resize(len);
    return ret;
  }

  // Scalar multiple: every coefficient times x (mod MOD).
  inline poly operator*(const int &x) const {
    poly ret(v);
    for (int i = 0; i < v.size(); ++i) ret[i] = 1ll * ret[i] * x % MOD;
    return ret;
  }

  // Coefficient-wise negation.
  inline poly operator-() const {
    poly ret(v);
    for (int i = 0; i < v.size(); ++i) ret[i] = dec(0, ret[i]);
    return ret;
  }

  // NTT product (result has size() + g.size() - 1 coefficients).
  inline poly operator*(const poly &g) const;

  // Polynomial quotient and remainder (defined out of line).
  inline poly operator/(const poly &g) const;

  inline poly operator%(const poly &g) const;

  // Formal derivative; the result has one coefficient fewer.
  // An empty poly is returned unchanged instead of evaluating
  // ret.size() - 1 on an empty vector (size_t underflow, out-of-range access).
  inline poly der() const {
    vector<int> ret(v);
    if (ret.empty()) return ret;
    for (int i = 0; i + 1 < (int)ret.size(); ++i)
      ret[i] = 1ll * ret[i + 1] * (i + 1) % MOD;
    ret.pop_back();
    return ret;
  }

  // Formal integral with zero constant term; one coefficient more.
  inline poly jifen() const {
    vector<int> ret(v);
    init_inv(ret.size());
    ret.push_back(0);
    for (int i = ret.size() - 1; i; --i)
      ret[i] = 1ll * ret[i - 1] * iv[i] % MOD;
    ret[0] = 0;
    return ret;
  }

  // Coefficients in reverse order.
  inline poly rev() const {
    vector<int> ret(v);
    reverse(ret.begin(), ret.end());
    return ret;
  }

  // Multiplicative inverse modulo x^{size()}.
  inline poly inv() const;

  // (*this) / FF modulo x^{size()}.
  inline poly div(const poly &FF) const;

  // Logarithm, exponential, power and square root modulo x^{size()}.
  inline poly ln() const;

  inline poly exp() const;

  inline poly pow(int k) const;

  inline poly sqrt() const;

  // Transposed multiplication (see the out-of-line definition).
  inline poly mulT(const poly &g, int siz, int tp) const;
};

// Coefficient-wise sum; the result is as long as the longer operand, and the
// shorter operand's missing coefficients read as 0.
inline poly operator+(const poly &x, const poly &y) {
  vector<int> out(max(x.v.size(), y.v.size()));
  int i = 0;
  for (int &c : out) {
    c = add(x[i], y[i]);
    ++i;
  }
  return out;
}

// Coefficient-wise difference x - y, same length rule as operator+.
inline poly operator-(const poly &x, const poly &y) {
  vector<int> out(max(x.v.size(), y.v.size()));
  int i = 0;
  for (int &c : out) {
    c = dec(x[i], y[i]);
    ++i;
  }
  return out;
}

// Scratch buffer for NTT. Entries are reduced lazily, so they may exceed
// MOD between reductions.
LL fr[N];

// In-place NTT of length lim on the first lim coefficients of f.
//   tp != 0: forward transform.
//   tp == 0: inverse transform, including the final multiplication by
//            ksm(lim, MOD - 2), i.e. 1 / lim.
// Requires init(lim) (bit-reversal table r) and init_poly tables (Wn)
// covering length lim. f is grown to at least lim coefficients; its
// coefficients are assumed to already be reduced into [0, MOD).
//
// Overflow guard: each butterfly level raises the bound on every entry by at
// most MOD. After s unreduced levels each entry is < (s + 1) * MOD, so
// fr[k + mid] * Wn[...] < (s + 1) * MOD * MOD. Reducing fr[] every 8 levels
// keeps that product below 8 * MOD^2 (about 8.0e18). This fits in 64 bits
// whether LL is signed or unsigned. Without the periodic reduction the bound
// grows to about log2(lim) * MOD^2 and overflows for long transforms.
inline void NTT(poly &f, int lim, int tp) {
  for (int i = 0; i < lim; ++i) fr[i] = f[r[i]];
  int lazy = 0;  // butterfly levels applied since fr[] was last reduced
  for (int mid = 1; mid < lim; mid <<= 1) {
    if (lazy == 8) {
      for (int i = 0; i < lim; ++i) fr[i] %= MOD;
      lazy = 0;
    }
    for (int len = mid << 1, l = 0; l + len - 1 < lim; l += len)
      for (int k = l; k < l + mid; ++k) {
        LL w1 = fr[k], w2 = fr[k + mid] * Wn[mid + k - l] % MOD;
        fr[k] = w1 + w2;
        fr[k + mid] = w1 + MOD - w2;
      }
    ++lazy;
  }
  for (int i = 0; i < lim; ++i)
    if (fr[i] >= MOD) fr[i] %= MOD;
  if (!tp) {
    // Reversing indices 1..lim-1 turns the forward transform into the
    // inverse one, up to the 1 / lim factor applied next.
    reverse(fr + 1, fr + lim);
    int inv_lim = ksm(lim, MOD - 2);
    for (int i = 0; i < lim; ++i) fr[i] = fr[i] * inv_lim % MOD;
  }
  for (int i = 0; i < lim; ++i) f[i] = fr[i];
}

// Polynomial product via NTT. The result has exactly
// size() + G.size() - 1 coefficients. lg[] and Wn[] must cover the transform
// length, i.e. init_poly(n) was called with n >= that result size.
inline poly poly::operator*(const poly &G) const {
  poly a(v);
  poly b = G;
  const int need = a.size() + b.size() - 1;
  const int lim = 1 << (lg[need] + 1);
  init(lim);
  NTT(a, lim, 1);
  NTT(b, lim, 1);
  for (int i = 0; i < lim; ++i) a[i] = 1ll * a[i] * b[i] % MOD;
  NTT(a, lim, 0);
  return a.slice(need);
}

// Multiplicative inverse modulo x^{size()}: returns g with
// (*this) * g == 1 (mod x^{size()}). v[0] must be invertible mod MOD.
// Newton iteration: each round doubles the number of correct coefficients
// of g, computing the new upper half as g[lim/2 .. lim) -= (g0 * (f * g0)).
inline poly poly::inv() const {
  poly g, g0, d;
  g[0] = ksm(v[0], MOD - 2);
  for (int lim = 2; (lim >> 1) < v.size(); lim <<= 1) {
    // g0: current inverse (correct mod x^{lim/2}); d: f mod x^{lim}.
    g0 = g;
    d = slice(lim);
    init(lim);
    NTT(g0, lim, 1);
    NTT(d, lim, 1);
    // d = f * g0 (cyclic, length lim).
    for (int i = 0; i < lim; ++i) d[i] = 1ll * g0[i] * d[i] % MOD;
    NTT(d, lim, 0);
    // Keep only the upper half (the error term). Cyclic wrap-around from the
    // length-lim product lands in the lower half, which is discarded here.
    fill(d.v.begin(), d.v.begin() + (lim >> 1), 0);
    NTT(d, lim, 1);
    // d = error * g0; again only the upper half is wrap-free and used.
    for (int i = 0; i < lim; ++i) d[i] = 1ll * d[i] * g0[i] % MOD;
    NTT(d, lim, 0);
    for (int i = lim >> 1; i < lim; ++i) g[i] = dec(g[i], d[i]);
  }
  return g.slice(v.size());
}

// Power-series division: returns Q with Q * FF == *this (mod x^{n}),
// n = size(). FF[0] must be invertible mod MOD.
// Q is first computed exactly mod x^{nlim} as H0 * inv(FF). One correction
// step, using the residual (Q0 * FF - *this) on [nlim, lim), then fixes the
// remaining coefficients (nlim <= n <= lim).
inline poly poly::div(const poly &FF) const {
  if (v.size() == 1) return 1ll * v[0] * ksm(FF[0], MOD - 2) % MOD;
  int len = lg[v.size()], lim = 1 << (len + 1), nlim = lim >> 1;
  // G0 = FF^{-1} mod x^{nlim}.
  poly F = FF, G0 = FF.slice(nlim);
  G0 = G0.inv();
  poly H0 = slice(nlim), Q0;

  // Q0 = H0 * G0 mod x^{nlim} (degrees < nlim each, so no wrap at length lim).
  init(lim);
  NTT(G0, lim, 1);
  NTT(H0, lim, 1);
  for (int i = 0; i < lim; ++i) Q0[i] = 1ll * G0[i] * H0[i] % MOD;
  NTT(Q0, lim, 0);
  Q0.resize(nlim);

  // Residual R = Q0 * F - (*this), kept only on [nlim, lim). Wrap-around of
  // the cyclic product falls below nlim and is zeroed.
  poly ret = Q0;
  NTT(Q0, lim, 1);
  NTT(F, lim, 1);
  for (int i = 0; i < lim; ++i) Q0[i] = 1ll * Q0[i] * F[i] % MOD;
  NTT(Q0, lim, 0);
  fill(Q0.v.begin(), Q0.v.begin() + nlim, 0);
  for (int i = nlim; i < lim && i < v.size(); ++i) Q0[i] = dec(Q0[i], v[i]);
  // Correction: ret[nlim .. lim) -= (R * G0)[nlim .. lim).
  NTT(Q0, lim, 1);
  for (int i = 0; i < lim; ++i) Q0[i] = 1ll * Q0[i] * G0[i] % MOD;
  NTT(Q0, lim, 0);
  for (int i = nlim; i < lim; ++i) ret[i] = dec(ret[i], Q0[i]);
  return ret.slice(v.size());
}

// Logarithm modulo x^{size()}, as the integral of f' / f. The integral's
// constant term is fixed to 0, which matches ln f only when f[0] == 1.
inline poly poly::ln() const {
  const poly quotient = der().div(*this);
  return quotient.jifen();
}

// State and helpers for poly::exp. With ret = exp(f), the coefficients
// satisfy n * ret[n] = sum_{i<n} ret[i] * (i-th term of x f'(x))_{n-i},
// which is evaluated with a 16-way blocked divide and conquer.
namespace EXP {
// Each recursion level splits [l, r) into B = 2^logB equal sub-blocks.
const int logB = 4;
const int B = 16;
// f: input series; init_exp rescales it so that f[i] holds i * f[i].
// ret: output series; entries accumulate contributions before being
//      finalized (divided by their index) in the base case.
// g[lim][i]: forward NTT (length 2^(lim - logB + 1)) of the f-segment
//            f[i*k .. i*k + 2k), k = 2^(lim - logB), built by init_exp.
poly f, ret, g[30][B];

// Finalizes ret[l .. r), assuming every contribution from ret[0 .. l) has
// already been added into ret[l .. r). For the call from poly::exp (and
// hence every recursive call) r - l == 1 << lim.
inline void exp(int lim, int l, int r) {
  // Small ranges: finalize each coefficient (ret[0] = 1, otherwise divide
  // the accumulated sum by i via iv[i]) and push its contribution to the
  // rest of the range directly, in O((r - l)^2).
  if (r - l <= 64) {
    for (int i = l; i < r; ++i) {
      ret[i] = (!i) ? 1 : 1ll * ret[i] * iv[i] % MOD;
      for (int j = i + 1; j < r; ++j)
        inc(ret[j], 1ll * ret[i] * f[j - i] % MOD);
    }
    return;
  }
  int k = (r - l) / B;
  // bl[j]: frequency-domain accumulator (length 2k) of contributions from
  // earlier sub-blocks into sub-block j.
  poly bl[B];
  for (auto & i : bl) i.resize(k << 1);
  int len = 1 << (lim - logB + 1);
  for (int i = 0; i < B; ++i) {
    if (i > 0) {
      // Back to coefficients; entries [k, 2k) are the contributions to
      // sub-block i (cyclic wrap-around only reaches indices below k).
      init(len);
      NTT(bl[i], len, 0);
      for (int j = 0; j < k; ++j) inc(ret[l + i * k + j], bl[i][j + k]);
    }
    exp(lim - logB, l + i * k, l + (i + 1) * k);
    if (i < B - 1) {
      // Transform the now-final sub-block i (zero-padded to 2k) and add its
      // product with the matching precomputed f-segment into every later
      // sub-block's accumulator.
      poly H;
      H.resize(k << 1);
      for (int j = 0; j < k; ++j) H[j] = ret[j + l + i * k];
      init(len);
      NTT(H, len, 1);
      for (int j = i + 1; j < B; ++j)
        for (int t = 0; t < (k << 1); ++t)
          inc(bl[j][t], 1ll * H[t] * g[lim][j - i - 1][t] % MOD);
    }
  }
}

// Prepares EXP state from f (assigned by the caller):
//   - rescales f[i] to i * f[i] and zero-fills ret;
//   - fills the inverse table up to 2^mx;
//   - precomputes g[lim][0 .. B-2] for lim = mx, mx - logB, ... (>= logB).
inline void init_exp() {
  ret.resize(f.size());
  for (int i = 0; i < f.size(); ++i) f[i] = 1ll * f[i] * i % MOD, ret[i] = 0;
  int mx = lg[f.size()] + 1;
  init_inv(1 << mx);
  for (int lim = mx; lim >= logB; lim -= logB) {
    int bl = 1 << (lim - logB), ll = 1 << (lim - logB + 1);
    init(ll);
    for (int i = 0; i < B - 1; ++i) {
      g[lim][i].resize(bl << 1);
      for (int j = 0; j < (bl << 1); ++j) g[lim][i][j] = f[j + bl * i];
      NTT(g[lim][i], ll, 1);
    }
  }
}
}  // namespace EXP

// Exponential modulo x^{size()}. The input's constant term is ignored
// (EXP fixes ret[0] = 1), so the result equals exp(A) only when A[0] == 0.
// Uses the globals in namespace EXP as working state.
inline poly poly::exp() const {
  const int top = lg[v.size()] + 1;
  EXP::f = *this;
  EXP::init_exp();
  EXP::exp(top, 0, 1 << top);
  return EXP::ret.slice(v.size());
}

// k-th power modulo x^{size()}, computed as exp(k * ln f). Like ln(), this
// is only correct when f[0] == 1.
inline poly poly::pow(int k) const {
  const poly scaled = ln() * k;
  return scaled.exp();
}

// Quotient of polynomial division by Q, computed via coefficient reversal:
// rev(quotient) = rev(*this) / rev(Q) mod x^p, with p = size() - Q.size() + 1.
// Returns 0 when *this has fewer coefficients than Q. Q's last stored
// coefficient must be invertible (it becomes the constant term of rev(Q)).
inline poly poly::operator/(const poly &Q) const {
  if (v.size() < Q.v.size()) return 0;
  const int p = v.size() - Q.v.size() + 1;
  poly num = rev();
  num.resize(p);
  poly den = Q.rev();
  den.resize(p);
  return num.div(den).rev();
}

// Remainder of polynomial division: *this - Q * (*this / Q), truncated to
// Q.size() - 1 coefficients.
inline poly poly::operator%(const poly &Q) const {
  const poly self(v);
  const poly quotient = self / Q;
  const poly product = Q * quotient;
  return (self - product).slice(Q.v.size() - 1);
}

// Square root modulo x^{size()}: returns g with g^2 == f (mod x^{size()}),
// g[0] being the smaller root returned by cipolla(f[0]).
// Assumes f[0] is a nonzero quadratic residue. Otherwise cipolla returns 0
// and the result is presumably not a valid root.
// Newton iteration g <- g - (g^2 - f) / (2g), maintained alongside
//   h  - inverse of g, correct to the current precision;
//   gf - length-lim forward NTT of the current g.
inline poly poly::sqrt() const {
  poly g, h, gf, F1, F2, F3, f(v);
  g[0] = cipolla(operator[](0));
  h[0] = ksm(g[0], MOD - 2);
  gf[0] = g[0];
  gf[1] = g[0];
  // iv here is 1/2 mod MOD (shadows the global inverse table).
  int iv = (MOD + 1) / 2;
  init(1);
  for (int lim = 1; lim < v.size(); lim <<= 1) {
    // gf^2 inverse-transformed gives g^2 mod (x^lim - 1). Since
    // (g^2)_i == f_i for i < lim, F1[i] - f[i] equals (g^2)_{i+lim}.
    for (int i = 0; i < lim; ++i) F1[i] = 1ll * gf[i] * gf[i] % MOD;
    NTT(F1, lim, 0);
    for (int i = 0; i < lim; ++i) F1[i + lim] = dec(F1[i], f[i]), F1[i] = 0;

    // F1 = (g^2 - f) restricted to [lim, 2lim); F2 = h mod x^lim.
    int nlim = lim << 1;
    init(nlim);
    for (int i = lim; i < nlim; ++i) rec(F1[i], f[i]);
    F2 = h;
    F2.resize(lim);
    NTT(F1, nlim, 1);
    NTT(F2, nlim, 1);

    // New coefficients: g[lim .. nlim) = -((g^2 - f) * h)/2 on that range.
    for (int i = 0; i < nlim; ++i) F1[i] = 1ll * F1[i] * F2[i] % MOD;
    NTT(F1, nlim, 0);
    for (int i = lim; i < nlim; ++i) g[i] = dec(0, 1ll * F1[i] * iv % MOD);

    // Refresh gf and extend h to precision nlim (inverse Newton step) when
    // another round will follow.
    if (nlim < v.size()) {
      gf = g;
      NTT(gf, nlim, 1);
      for (int i = 0; i < nlim; ++i) F3[i] = 1ll * gf[i] * F2[i] % MOD;
      NTT(F3, nlim, 0);
      fill(F3.v.begin(), F3.v.begin() + lim, 0);
      NTT(F3, nlim, 1);
      for (int i = 0; i < nlim; ++i) F3[i] = 1ll * F3[i] * F2[i] % MOD;
      NTT(F3, nlim, 0);
      for (int i = lim; i < nlim; ++i) rec(h[i], F3[i]);
    }
  }
  return g.slice(v.size());
}

// Transposed multiplication: result[m] = sum_j f[m + j] * G[j].
// NTT path: convolves f with reversed G and returns exactly siz
// coefficients, starting at index G.size() - 1. The transform length is
// derived from f.size() * tp, so tp scales the padding.
// Brute-force path (f.size() <= 100): when tp == 1, only terms with
// m <= f.size() - G.size() are accumulated.
// NOTE(review): the brute-force path returns every computed coefficient
// rather than exactly siz of them; callers presumably read only the first
// siz — confirm.
inline poly poly::mulT(const poly &G, int siz, int tp) const {
  poly f(v), g = G;
  if (f.size() <= 100) {
    poly ret;
    for (int i = 0; i < f.size(); ++i) {
      // fr: lowest j to use (local; shadows the global NTT buffer).
      int fr = 0;
      if (tp == 1) fr = max(fr, i - (int)(f.size() - g.size()));
      for (int j = fr; j <= i && j < g.size(); ++j)
        inc(ret[i - j], 1ll * f[i] * g[j] % MOD);
    }
    return ret;
  }
  int len = lg[f.size() * tp], lim = 1 << (len + 1), gg = g.size();
  init(lim);
  reverse(g.v.begin(), g.v.end());
  NTT(f, lim, 1);
  NTT(g, lim, 1);
  for (int i = 0; i < lim; ++i) f[i] = 1ll * f[i] * g[i] % MOD;
  NTT(f, lim, 0);
  return vector<int>(f.v.begin() + gg - 1, f.v.begin() + gg + siz - 1);
}

多项式乘法(NTT)

给定一个 $n$ 次多项式 $F(x)$,和一个 $m$ 次多项式 $G(x)$。请求出 $F(x)$ 和 $G(x)$ 的卷积。

  // Read F (degree n) and G (degree m), then print their product.
  int n, m;
  cin >> n >> m;
  poly a, b;
  // The product has n + m + 1 coefficients; size the NTT tables for that.
  init_poly(n + m + 1);
  a.resize(n + 1), b.resize(m + 1);
  for (int i = 0; i <= n; ++i) cin >> a.v[i];
  for (int i = 0; i <= m; ++i) cin >> b.v[i];
  for (auto i : (a * b).v) cout << i << " ";

多项式求乘法逆

给定一个多项式 $F(x)$ ,请求出一个多项式 $G(x)$, 满足 $F(x) * G(x) \equiv 1 \pmod{x^n}$。

  // Read n coefficients of F and print F^{-1} mod x^n.
  int n;
  cin >> n;
  poly a;
  init_poly(n);
  a.resize(n);
  for (int i = 0; i < n; ++i) cin >> a.v[i];
  for (auto i : a.inv().v) cout << i << " ";

多项式开根

给定一个 $n-1$ 次多项式 $A(x)$,求一个在 $\bmod\ x^n$ 意义下的多项式 $B(x)$,使得 $B^2(x) \equiv A(x) \pmod{x^n}$。若有多解,请取零次项系数较小的作为答案。

  // Read n coefficients of A and print sqrt(A) mod x^n (the root whose
  // constant term is the smaller one, as chosen by cipolla).
  int n;
  cin >> n;
  poly a;
  init_poly(n);
  a.resize(n);
  for (int i = 0; i < n; ++i) cin >> a.v[i];
  for (auto i : a.sqrt().v) cout << i << " ";

多项式除法

给定一个 $n$ 次多项式 $F(x)$ 和一个 $m$ 次多项式 $G(x)$ ,请求出多项式 $Q(x)$, $R(x)$,满足 $Q(x)$ 次数为 $n-m$,$R(x)$ 次数小于 $m$ 且$F(x) = Q(x) * G(x) + R(x)$

  // Read F (degree n) and G (degree m); print the quotient Q on the first
  // line and the remainder R on the second.
  int n, m;
  cin >> n >> m;
  poly a, b;
  // a % b multiplies G by Q (n + 1 coefficients), so size tables for n + 1.
  init_poly(n + 1);
  a.resize(n + 1), b.resize(m + 1);
  for (int i = 0; i <= n; ++i) cin >> a.v[i];
  for (int i = 0; i <= m; ++i) cin >> b.v[i];
  for (auto i : (a / b).v) cout << i << " ";
  cout << "\n";
  for (auto i : (a % b).v) cout << i << " ";

多项式对数函数(多项式 ln)

给出 $n-1$ 次多项式 $A(x)$,求一个  $\bmod{x^n}$ 下的多项式 $B(x)$,满足 $B(x) \equiv \ln A(x)$.

  // Read n coefficients of A (with a_0 = 1) and print ln A mod x^n.
  int n;
  cin >> n;
  poly f;
  f.resize(n);
  for (int i = 0; i < n; ++i) cin >> f.v[i];
  // Tables only need to exist before the first transform, so this may
  // come after reading the input.
  init_poly(n);
  for (auto i : f.ln().v) cout << i << " ";

多项式指数函数(多项式 exp)

给出 $n-1$ 次多项式 $A(x)$,求一个 $\bmod{x^n}$ 下的多项式 $B(x)$,满足 $B(x) \equiv \text e^{A(x)}$。

  // Read n coefficients of A (with a_0 = 0) and print exp(A) mod x^n.
  int n;
  cin >> n;
  poly a;
  init_poly(n);
  a.resize(n);
  for (int i = 0; i < n; ++i) cin >> a.v[i];
  for (auto i : a.exp().v) cout << i << " ";

多项式快速幂

给定一个 $n-1$ 次多项式 $A(x)$,求一个在 $\bmod\ x^n$ 意义下的多项式 $B(x)$,使得 $B(x) \equiv (A(x))^k \ (\bmod\ x^n)$。

当保证$a_0=1$时

  // a_0 == 1 case: A^k = exp(k * ln A), so k only enters as a multiplier
  // and may be reduced mod MOD while parsing its (possibly huge) decimal
  // string.
  LL n, k = 0;
  string s;
  cin >> n >> s;
  for (auto i : s) {
    k *= 10;
    k += i - '0';
    k %= MOD;
  }
  poly a;
  init_poly(n);
  a.resize(n);
  for (int i = 0; i < n; ++i) cin >> a.v[i];
  for (auto i : a.pow(k).v) cout << i << " ";

当不保证 $a_0=1$ 时

// General case (a_0 may be 0 or != 1). Write A = x^len * a_0 * A',
// with A'(0) = 1. Then A^k = x^{len*k} * a_0^k * exp(k ln A').
LL n, k1 = 0, k2 = 0;
string s;
cin >> n >> s;
// k arrives as a decimal string: k1 = k mod MOD multiplies ln A',
// k2 = k mod (MOD - 1) is the exponent for a_0^k (Fermat, MOD prime).
for (auto i : s) {
  k1 *= 10;
  k1 += i - '0';
  k1 %= MOD;
  k2 *= 10;
  k2 += i - '0';
  k2 %= MOD - 1;
}
poly a;
init_poly(n);
a.resize(n);
LL len = 0;
for (int i = 0; i < n; ++i) cin >> a[i];
// len = number of leading zero coefficients.
for (auto i : a.v) {
  if (!i)
    len++;
  else
    break;
}
// Shift A down by len so the new constant term is nonzero.
for (int i = 0; i < n; ++i) {
  if (i + len < n)
    a[i] = a[i + len];
  else
    break;
}
a.resize(n - len);
// Divide by a_0 (tmp = a_0^{-1}) so the constant term becomes 1.
LL tmp = ksm(a[0], MOD - 2), tmp1 = a[0];
for (int i = 0; i < n - len; ++i) a[i] = a[i] * tmp % MOD;
a = a.pow(k1);
// Multiply back by a_0^k.
tmp = ksm(tmp1, k2);
for (int i = 0; i < n - len; ++i) a[i] = a[i] * tmp % MOD;
// Print shifted up by len * k. When s has at most 5 digits, k < MOD, so
// k1 == k. When len > 0 and k has more than 5 digits, k >= 10^5.
// NOTE(review): this presumably relies on n <= 10^5 and on s having no
// leading zeros, so that the shift is >= n and every output is 0 — confirm.
for (int i = 0; i < n; ++i) {
  if (i < len * k1 || (len && s.length() > 5))
    cout << "0 ";
  else
    cout << a[i - len * k1] << " ";
}

多项式三角函数

首先由欧拉公式(Euler's formula)$e^{ix} = \cos{x} + i\sin{x}$ 可以得到三角函数的另一种表达式:

$$ \sin{x} = \frac{e^{ix} - e^{-ix}}{2i} $$

$$ \cos{x} = \frac{e^{ix} + e^{-ix}}{2} $$

那么代入 $f\left(x\right)$ 就有:

$$ \sin{f\left(x\right)} = \frac{\exp{\left(if\left(x\right)\right)} - \exp{\left(-if\left(x\right)\right)}}{2i} $$

$$ \cos{f\left(x\right)} = \frac{\exp{\left(if\left(x\right)\right)} + \exp{\left(-if\left(x\right)\right)}}{2} $$

注意到我们是在 $\mathbb{Z}_{998244353}$ 上做 NTT,那么相应地,虚数单位 $i$ 应该被换成 $86583718$ 或 $911660635$:$i = \sqrt{-1} \equiv \sqrt{998244352} \equiv 86583718 \equiv 911660635 \pmod{998244353}$。

拉格朗日插值

$$ f(x)=\sum_{i=1}^{n} y_i \prod_{j \ne i} \dfrac{x-x_j}{x_i-x_j} $$

// Lagrange interpolation: evaluates at x = k the polynomial through the
// points (x[i], y[i]), i = 1..n, accumulating the value into ans.
// Expects ans == 0 beforehand and pairwise distinct x[i] modulo MOD.
// Every factor is normalized into [0, MOD) first, so negative differences
// such as k - x[j] cannot leave negative residues in s1, s2 or ans, and ans
// stays reduced. The denominator is inverted with ksm(s2, MOD - 2) (Fermat),
// matching the rest of this template; the original used an undefined free
// `inv` and a lowercase `mod`.
for (int i = 1; i <= n; i++) {
  s1 = (y[i] % MOD + MOD) % MOD;
  s2 = 1;
  for (int j = 1; j <= n; j++)
    if (i != j) {
      s1 = 1ll * s1 * (((k - x[j]) % MOD + MOD) % MOD) % MOD;
      s2 = 1ll * s2 * (((x[i] - x[j]) % MOD + MOD) % MOD) % MOD;
    }
  ans = (ans + 1ll * s1 * ksm(s2, MOD - 2) % MOD) % MOD;
}
Hello, world!
最后更新于 2022-04-14