Hedwig's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View on GitHub

✔ math/mod_comb.cpp

Depends on

Verified with

Code

#pragma once
#include <bits/stdc++.h>
using namespace std;
#include "../template/const.hpp"
#include "./mint.cpp"

// Combination
// mod MODで階乗を計算しておくことでcomb(n,k)などをO(1)
// で計算する. MODはconstな素数
template <const int MOD>
struct Combination {
    using mint = ModInt<MOD>;
    int n;
    vector<mint> fact, ifact, invs;

    Combination(int n) : n(n) {
        fact.resize(n + 1);
        ifact.resize(n + 1);
        fact[0] = fact[1] = 1;
        ifact[0] = ifact[1] = 1;
        for (int i = 2; i <= n; ++i) {
            fact[i] = fact[i - 1] * i;
        }
        ifact[n] = fact[n].inverse();
        for (int i = n; i >= 1; --i) {
            ifact[i - 1] = ifact[i] * i;
        }
    }

    // invs_build
    // ax = 1 mod MODをみたすxをa=1,...,nについて計算する.
    // 計算量: O(n)
    void invs_build() {
        invs.resize(n + 1);
        invs[1] = 1;
        for (int i = 2; i <= n; ++i)
            invs[i] = fact[i] * ifact[i - 1];
    }

    // (n,k)
    // 0 <= k <= nなら nCk を返し, そうでないなら0を返す.
    // 制約: k,n整数
    mint operator()(int n, int k) {
        if (k < 0 || k > n) return 0;
        return fact[n] * ifact[k] * ifact[n - k];
    }

    // npk
    // 0 <= k <= nなら nPk を返し, そうでないなら0を返す.
    // 制約: k,n整数
    mint npk(int n, int k) {
        if (k < 0 || k > n) return 0;
        return fact[n] * ifact[n - k];
    }
};

// nck_nbig
// Computes nCk modulo MOD when n is too large for a factorial table, using
// nCk = n (n-1) ... (n-k+1) / k!.
// Returns 0 for k < 0. MOD is expected to be prime.
// Complexity: O(k) multiplications plus a single modular inversion.
template <const int MOD>
ModInt<MOD> nck_nbig(long long n, int k) {
    using mint = ModInt<MOD>;
    if (k < 0) return mint(0);
    mint numer = 1;
    mint denom = 1;
    for (int i = 0; i < k; ++i) {
        numer *= mint(n - i);
        denom *= mint(i + 1);
    }
    // Invert the accumulated k! once instead of inverting every factor.
    return numer * denom.inverse();
}

// modpow
// Computes x^y mod m by binary exponentiation; the result is in [0, m).
// A non-positive exponent y yields 1 mod m.
// Preconditions: m >= 1, and (m - 1)^2 must fit in a long long
// (m up to about 3 * 10^9).
// Complexity: O(log y)
long long modpow(long long x, long long y, long long m) {
    // Reduce the base first so the squarings below cannot overflow for large
    // |x|, and shift it into [0, m) so negative bases give a canonical result.
    long long base = x % m;
    if (base < 0) base += m;
    long long result = 1 % m;  // 0 when m == 1
    while (y > 0) {
        if (y & 1) {
            result = (result * base) % m;
        }
        y >>= 1;
        base = (base * base) % m;
    }
    return result;
}
#line 2 "math/mod_comb.cpp"
#include <bits/stdc++.h>
using namespace std;
#line 2 "template/const.hpp"
// Shared numeric constants.
constexpr int INF = 1'000'000'000;                       // 10^9
constexpr long long HINF = 4'000'000'000'000'000'000LL;  // 4 * 10^18
constexpr long long MOD = 998244353;                     // modulus constant
constexpr double EPS = 1e-6;                             // floating-point tolerance value
constexpr double PI = 3.14159265358979;                  // pi to 14 decimal places
#line 3 "math/mint.cpp"
using namespace std;

// ModInt
// Integer modulo the compile-time constant MOD. The stored value x is kept
// normalised to [0, MOD) by the constructor and by every arithmetic operator.
// Division and inverse() give a true modular inverse only when the value is
// coprime to MOD (always the case for non-zero values when MOD is prime).
template <int MOD>
struct ModInt {
  public:
    long long x;  // residue in [0, MOD)

    // Accepts any long long, including negatives, and reduces it into [0, MOD).
    constexpr ModInt(long long x = 0) noexcept : x((x % MOD + MOD) % MOD) {}

    constexpr ModInt &operator+=(const ModInt a) noexcept {
        if ((x += a.x) >= MOD) x -= MOD;
        return *this;
    }
    constexpr ModInt &operator-=(const ModInt a) noexcept {
        // Add MOD - a.x rather than subtracting so x never goes negative.
        if ((x += MOD - a.x) >= MOD) x -= MOD;
        return *this;
    }
    constexpr ModInt &operator*=(const ModInt a) noexcept {
        (x *= a.x) %= MOD;
        return *this;
    }
    constexpr ModInt &operator/=(const ModInt a) noexcept { return *this *= a.inverse(); }

    // The operand is already normalised, so forward it as-is instead of
    // round-tripping its raw value through the constructor again.
    constexpr ModInt operator+(const ModInt a) const noexcept { return ModInt(*this) += a; }
    constexpr ModInt operator-(const ModInt a) const noexcept { return ModInt(*this) -= a; }
    constexpr ModInt operator*(const ModInt a) const noexcept { return ModInt(*this) *= a; }
    constexpr ModInt operator/(const ModInt a) const noexcept { return ModInt(*this) /= a; }

    // Stream I/O may throw when the stream has exceptions enabled, so these
    // are deliberately not noexcept.
    friend std::ostream &operator<<(std::ostream &os, const ModInt a) { return os << a.x; }
    friend std::istream &operator>>(std::istream &is, ModInt &a) {
        is >> a.x;
        a.x = (a.x % MOD + MOD) % MOD;
        return is;
    }

    // Returns x^(-1) via the extended Euclidean algorithm.
    constexpr ModInt inverse() const noexcept {
        long long a = x, b = MOD, p = 1, q = 0;
        while (b != 0) {
            const long long d = a / b;
            a -= d * b;
            p -= d * q;
            // Manual swaps: std::swap is not constexpr before C++20.
            long long t = a;
            a = b;
            b = t;
            t = p;
            p = q;
            q = t;
        }
        return ModInt(p);
    }

    // Returns x^N by binary exponentiation. A negative N is evaluated as
    // inverse()^(-N).
    constexpr ModInt pow(long long N) const noexcept {
        ModInt base = (N < 0) ? inverse() : *this;
        // Take the magnitude in unsigned arithmetic so that LLONG_MIN is safe
        // and the shift below always terminates.
        unsigned long long e = (N < 0) ? 0ULL - static_cast<unsigned long long>(N)
                                       : static_cast<unsigned long long>(N);
        ModInt result = 1;
        while (e != 0) {
            if (e & 1) result *= base;
            base *= base;
            e >>= 1;
        }
        return result;
    }
};

// Scalar-on-the-left multiplication: c * a.
// The scalar is converted to ModInt (and thereby reduced modulo MOD) before
// multiplying, so a large c cannot overflow the raw long long product.
template <typename U, int MOD>
inline ModInt<MOD> operator*(const U &c, const ModInt<MOD> &a) {
    ModInt<MOD> product(c);
    product *= a;
    return product;
}

// Default modular integer type: ModInt with modulus 998244353.
using mint = ModInt<998244353>;
#line 6 "math/mod_comb.cpp"

// Combination
// mod MODで階乗を計算しておくことでcomb(n,k)などをO(1)
// で計算する. MODはconstな素数
template <const int MOD>
struct Combination {
    using mint = ModInt<MOD>;
    int n;
    vector<mint> fact, ifact, invs;

    Combination(int n) : n(n) {
        fact.resize(n + 1);
        ifact.resize(n + 1);
        fact[0] = fact[1] = 1;
        ifact[0] = ifact[1] = 1;
        for (int i = 2; i <= n; ++i) {
            fact[i] = fact[i - 1] * i;
        }
        ifact[n] = fact[n].inverse();
        for (int i = n; i >= 1; --i) {
            ifact[i - 1] = ifact[i] * i;
        }
    }

    // invs_build
    // ax = 1 mod MODをみたすxをa=1,...,nについて計算する.
    // 計算量: O(n)
    void invs_build() {
        invs.resize(n + 1);
        invs[1] = 1;
        for (int i = 2; i <= n; ++i)
            invs[i] = fact[i] * ifact[i - 1];
    }

    // (n,k)
    // 0 <= k <= nなら nCk を返し, そうでないなら0を返す.
    // 制約: k,n整数
    mint operator()(int n, int k) {
        if (k < 0 || k > n) return 0;
        return fact[n] * ifact[k] * ifact[n - k];
    }

    // npk
    // 0 <= k <= nなら nPk を返し, そうでないなら0を返す.
    // 制約: k,n整数
    mint npk(int n, int k) {
        if (k < 0 || k > n) return 0;
        return fact[n] * ifact[n - k];
    }
};

// nck_nbig
// Computes nCk modulo MOD when n is too large for a factorial table, using
// nCk = n (n-1) ... (n-k+1) / k!.
// Returns 0 for k < 0. MOD is expected to be prime.
// Complexity: O(k) multiplications plus a single modular inversion.
template <const int MOD>
ModInt<MOD> nck_nbig(long long n, int k) {
    using mint = ModInt<MOD>;
    if (k < 0) return mint(0);
    mint numer = 1;
    mint denom = 1;
    for (int i = 0; i < k; ++i) {
        numer *= mint(n - i);
        denom *= mint(i + 1);
    }
    // Invert the accumulated k! once instead of inverting every factor.
    return numer * denom.inverse();
}

// modpow
// Computes x^y mod m by binary exponentiation; the result is in [0, m).
// A non-positive exponent y yields 1 mod m.
// Preconditions: m >= 1, and (m - 1)^2 must fit in a long long
// (m up to about 3 * 10^9).
// Complexity: O(log y)
long long modpow(long long x, long long y, long long m) {
    // Reduce the base first so the squarings below cannot overflow for large
    // |x|, and shift it into [0, m) so negative bases give a canonical result.
    long long base = x % m;
    if (base < 0) base += m;
    long long result = 1 % m;  // 0 when m == 1
    while (y > 0) {
        if (y & 1) {
            result = (result * base) % m;
        }
        y >>= 1;
        base = (base * base) % m;
    }
    return result;
}
Back to top page