Header File

Header file contents:

  • ModInt
  • Fenwick Tree
  • Iterative Segment Tree
  • Lazy Segment Tree (range add + range sum)
  • DSU
  • LCA
  • Dijkstra
  • 0-1 BFS
  • Sparse Table
  • KMP
  • Z-function
  • Trie
  • Topological Sort
  • Kruskal MST
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using pii = std::pair<int, int>;
// 2^62: sentinel for "unreachable" in 64-bit distance arrays.
constexpr ll INF64 = ll{1} << 62;
// 1e9: sentinel for 32-bit distances; INF + INF is still below INT_MAX (2^31 - 1).
constexpr int INF = 1'000'000'000;

/* =========================================================
 * 1) ModInt
 *    - Modular integer template; the stored value v is always in [0, MOD)
 *    - Works for any MOD in [1, INT_MAX]: sums and products are formed
 *      in 64-bit arithmetic
 *    - inv() / division require MOD to be prime (Fermat's little theorem)
 * ========================================================= */
template <int MOD>
struct ModInt {
    static_assert(MOD >= 1, "ModInt: MOD must be positive");

    int v;

    // Reduce any long long (negative values included) into [0, MOD).
    // Implicit on purpose so that `mint x = 5;` and `x + 1` compile.
    ModInt(long long x = 0) {
        if (x < 0) x = x % MOD + MOD;
        if (x >= MOD) x %= MOD;
        v = int(x);
    }

    int val() const { return v; }

    explicit operator int() const { return v; }

    // Additive inverse; -0 stays 0 so v remains in [0, MOD).
    ModInt operator-() const {
        ModInt r;
        r.v = (v == 0) ? 0 : MOD - v;
        return r;
    }

    ModInt& operator+=(const ModInt& other) {
        // Widen before adding: v + other.v can exceed INT_MAX once MOD > 2^30
        // (e.g. MOD = 2^31 - 1), and signed int overflow is undefined behavior.
        long long s = 1LL * v + other.v;
        if (s >= MOD) s -= MOD;
        v = int(s);
        return *this;
    }

    ModInt& operator-=(const ModInt& other) {
        v -= other.v;  // result >= -(MOD - 1), so no overflow in int
        if (v < 0) v += MOD;
        return *this;
    }

    ModInt& operator*=(const ModInt& other) {
        // Both factors are < 2^31, so the product is < 2^62 and fits in long long.
        v = int((1LL * v * other.v) % MOD);
        return *this;
    }

    // a^e by binary exponentiation. e <= 0 returns 1.
    static ModInt power(ModInt a, long long e) {
        ModInt r = 1;
        while (e > 0) {
            if (e & 1) r *= a;
            a *= a;
            e >>= 1;
        }
        return r;
    }

    // Multiplicative inverse as a^(MOD-2). Correct only when MOD is prime
    // and *this is nonzero.
    ModInt inv() const {
        return power(*this, MOD - 2);
    }

    ModInt& operator/=(const ModInt& other) {
        return (*this) *= other.inv();
    }

    friend ModInt operator+(ModInt a, const ModInt& b) { return a += b; }
    friend ModInt operator-(ModInt a, const ModInt& b) { return a -= b; }
    friend ModInt operator*(ModInt a, const ModInt& b) { return a *= b; }
    friend ModInt operator/(ModInt a, const ModInt& b) { return a /= b; }

    friend bool operator==(const ModInt& a, const ModInt& b) { return a.v == b.v; }
    friend bool operator!=(const ModInt& a, const ModInt& b) { return a.v != b.v; }

    friend ostream& operator<<(ostream& os, const ModInt& x) {
        return os << x.v;
    }

    // Reads a long long and reduces it modulo MOD.
    friend istream& operator>>(istream& is, ModInt& x) {
        long long y;
        is >> y;
        x = ModInt(y);
        return is;
    }
};

/* =========================================================
 * 2) Fenwick Tree / BIT
 *    - Point update, prefix sum, range sum; O(log n) each
 *    - 0-based indices at the interface (1-based internally)
 *    - Out-of-range indices are ignored (add) or clamped (queries)
 * ========================================================= */
template <typename T>
struct Fenwick {
    int n;
    vector<T> bit;  // bit[1..n]; bit[0] is unused

    Fenwick(int n = 0) { init(n); }

    // Reset to n zero-valued elements.
    void init(int n_) {
        n = n_;
        bit.assign(n + 1, T{});
    }

    // a[idx] += val. Indices outside [0, n) are ignored.
    void add(int idx, T val) {
        // idx < 0 would start the loop at position <= 0, where idx & -idx does
        // not advance (infinite loop); idx >= n already did nothing before.
        if (idx < 0 || idx >= n) return;
        for (++idx; idx <= n; idx += idx & -idx) bit[idx] += val;
    }

    // Prefix sum: sum of [0..idx]. idx < 0 gives T{}; idx >= n is clamped to
    // n - 1 (the sum of the whole array) instead of reading past the end.
    T sumPrefix(int idx) const {
        idx = min(idx, n - 1);
        if (idx < 0) return T{};
        T res{};
        for (++idx; idx > 0; idx -= idx & -idx) res += bit[idx];
        return res;
    }

    // Range sum: sum of [l..r]; empty (T{}) when l > r.
    T rangeSum(int l, int r) const {
        if (l > r) return T{};
        return sumPrefix(r) - sumPrefix(l - 1);
    }
};

/* =========================================================
 * 3) Iterative Segment Tree
 *    - Point assignment and range query, O(log n) each
 *    - `merge` must be associative and `ID` its identity element;
 *      merge need not be commutative (left-to-right order is kept)
 * ========================================================= */
template <typename T, class F>
struct SegTree {
    int n;         // leaf count: smallest power of two >= input size
    T ID;          // identity element of merge
    F merge;       // associative combine function
    vector<T> st;  // st[1] is the root; leaves occupy st[n .. 2n-1]

    SegTree() {}

    SegTree(const vector<T>& a, T id, F merge_) : ID(id), merge(merge_) {
        init(a);
    }

    // Rebuild from `a`; padding leaves beyond a.size() hold ID.
    void init(const vector<T>& a) {
        const int len = (int)a.size();
        n = 1;
        while (n < len) n *= 2;
        st.assign(2 * n, ID);
        copy(a.begin(), a.end(), st.begin() + n);
        for (int node = n - 1; node > 0; --node) {
            st[node] = merge(st[2 * node], st[2 * node + 1]);
        }
    }

    // Set a[pos] = val and refresh every ancestor of that leaf.
    void setVal(int pos, T val) {
        int node = pos + n;
        st[node] = val;
        while (node > 1) {
            node /= 2;
            st[node] = merge(st[2 * node], st[2 * node + 1]);
        }
    }

    // Fold merge over a[l..r] (inclusive), in left-to-right order.
    T query(int l, int r) const {
        T accL = ID;
        T accR = ID;
        int lo = l + n;
        int hi = r + n + 1;  // half-open window [lo, hi)
        for (; lo < hi; lo /= 2, hi /= 2) {
            if (lo % 2 == 1) accL = merge(accL, st[lo++]);
            if (hi % 2 == 1) accR = merge(st[--hi], accR);
        }
        return merge(accL, accR);
    }
};

/* =========================================================
 * 4) Lazy Segment Tree
 *    - Range add
 *    - Range sum query
 *    - 0-based inclusive ranges, O(log n) per operation
 *    - Parts of [ql, qr] outside [0, n) are ignored; ql > qr is a no-op
 * ========================================================= */
struct LazySegTree {
    // Defaults to 0 so a default-constructed tree is a valid empty tree:
    // rangeAdd/rangeSum test `n == 0` and must not read an indeterminate value.
    int n = 0;
    vector<ll> tree;  // tree[node]: sum of the node's segment, pending adds included
    vector<ll> lazy;  // lazy[node]: add not yet pushed down to the node's children

    LazySegTree() {}

    LazySegTree(const vector<ll>& a) {
        init(a);
    }

    // (Re)build over a. 4 * n node slots are enough for a recursive segment tree.
    void init(const vector<ll>& a) {
        n = (int)a.size();
        tree.assign(4 * max(1, n), 0);
        lazy.assign(4 * max(1, n), 0);
        if (n > 0) build(1, 0, n - 1, a);
    }

    void build(int node, int l, int r, const vector<ll>& a) {
        if (l == r) {
            tree[node] = a[l];
            return;
        }
        int m = (l + r) >> 1;
        build(node << 1, l, m, a);
        build(node << 1 | 1, m + 1, r, a);
        tree[node] = tree[node << 1] + tree[node << 1 | 1];
    }

    // Add val to every element of node's segment [l, r] and remember it for the children.
    void apply(int node, int l, int r, ll val) {
        tree[node] += val * (r - l + 1);
        lazy[node] += val;
    }

    // Hand node's pending add down to its two children (leaves have none).
    void push(int node, int l, int r) {
        if (lazy[node] == 0 || l == r) return;
        int m = (l + r) >> 1;
        apply(node << 1, l, m, lazy[node]);
        apply(node << 1 | 1, m + 1, r, lazy[node]);
        lazy[node] = 0;
    }

    // Add val to all elements in [ql, qr]
    void rangeAdd(int ql, int qr, ll val) {
        if (n == 0) return;
        rangeAdd(1, 0, n - 1, ql, qr, val);
    }

    void rangeAdd(int node, int l, int r, int ql, int qr, ll val) {
        if (qr < l || r < ql) return;
        if (ql <= l && r <= qr) {
            apply(node, l, r, val);
            return;
        }
        push(node, l, r);
        int m = (l + r) >> 1;
        rangeAdd(node << 1, l, m, ql, qr, val);
        rangeAdd(node << 1 | 1, m + 1, r, ql, qr, val);
        tree[node] = tree[node << 1] + tree[node << 1 | 1];
    }

    // Query sum on [ql, qr]; 0 for an empty tree or empty range.
    ll rangeSum(int ql, int qr) {
        if (n == 0) return 0;
        return rangeSum(1, 0, n - 1, ql, qr);
    }

    ll rangeSum(int node, int l, int r, int ql, int qr) {
        if (qr < l || r < ql) return 0;
        if (ql <= l && r <= qr) return tree[node];
        push(node, l, r);
        int m = (l + r) >> 1;
        return rangeSum(node << 1, l, m, ql, qr)
             + rangeSum(node << 1 | 1, m + 1, r, ql, qr);
    }
};

/* =========================================================
 * 5) DSU / Union-Find
 *    - Union by size + path compression
 * ========================================================= */
struct DSU {
    int n;
    vector<int> p, sz;  // p[x]: parent (p[x] == x for roots); sz[r]: size of set rooted at r

    DSU(int n = 0) { init(n); }

    // Reset to n singleton sets {0}, {1}, ..., {n-1}.
    void init(int n_) {
        n = n_;
        p.resize(n);
        for (int i = 0; i < n; ++i) p[i] = i;
        sz.assign(n, 1);
    }

    // Representative of x's set; every node on the path is re-pointed at it.
    int find(int x) {
        int root = x;
        while (p[root] != root) root = p[root];
        while (p[x] != root) {
            const int next = p[x];
            p[x] = root;
            x = next;
        }
        return root;
    }

    // Merge the sets of a and b. Returns false if they were already together.
    bool unite(int a, int b) {
        int ra = find(a);
        int rb = find(b);
        if (ra == rb) return false;
        if (sz[ra] < sz[rb]) swap(ra, rb);
        // Hang the smaller tree under the larger one.
        sz[ra] += sz[rb];
        p[rb] = ra;
        return true;
    }

    bool same(int a, int b) {
        return find(a) == find(b);
    }

    // Number of elements in x's set.
    int size(int x) {
        return sz[find(x)];
    }
};

/* =========================================================
 * 6) LCA by Binary Lifting
 *    - Lowest common ancestor, O(log n) per query
 *    - Distance on tree (number of edges)
 *    - K-th ancestor
 *    - build() is O(n log n) and roots only the component that
 *      contains `root`; the graph is expected to be a tree
 * ========================================================= */
struct LCA {
    int n, LOG;
    vector<vector<int>> g;   // undirected adjacency list
    vector<vector<int>> up;  // up[k][v]: 2^k-th ancestor of v (the root is its own parent)
    vector<int> depth;       // depth[v]: number of edges from the root to v

    LCA(int n = 0) { init(n); }

    // Reset to n isolated vertices. LOG is the smallest value with
    // 2^LOG > max(1, n), so any depth difference fits in LOG bits.
    void init(int n_) {
        n = n_;
        g.assign(n, {});
        depth.assign(n, 0);
        LOG = 1;
        while ((1 << LOG) <= max(1, n)) ++LOG;
        up.assign(LOG, vector<int>(n, 0));
    }

    void addEdge(int u, int v) {
        g[u].push_back(v);
        g[v].push_back(u);
    }

    // Root the tree at `root` and fill depth[] and up[][]. Uses an explicit
    // stack, so deep (path-like) trees cannot overflow the call stack.
    void build(int root = 0) {
        if (n == 0) return;  // no vertices: parent[root] would be out of range
        vector<int> parent(n, -1);
        stack<int> st;
        st.push(root);
        parent[root] = root;
        up[0][root] = root;
        depth[root] = 0;

        while (!st.empty()) {
            int u = st.top();
            st.pop();
            for (int v : g[u]) {
                if (v == parent[u]) continue;
                parent[v] = u;
                depth[v] = depth[u] + 1;
                up[0][v] = u;
                st.push(v);
            }
        }

        for (int k = 1; k < LOG; ++k) {
            for (int v = 0; v < n; ++v) {
                up[k][v] = up[k - 1][up[k - 1][v]];
            }
        }
    }

    // Ancestor of u that is k levels up. k <= 0 returns u; for vertices
    // reached by build(), k >= depth[u] returns the root.
    int jump(int u, int k) const {
        if (k <= 0) return u;
        // Clamp first: bits of k at or above LOG have no up[] level and were
        // previously dropped silently. depth[u] < n < 2^LOG, so after clamping
        // every set bit has a table level.
        k = min(k, depth[u]);
        for (int i = 0; k > 0; ++i, k >>= 1) {
            if (k & 1) u = up[i][u];
        }
        return u;
    }

    int lca(int a, int b) const {
        if (depth[a] < depth[b]) swap(a, b);
        a = jump(a, depth[a] - depth[b]);  // lift a to b's depth
        if (a == b) return a;
        // Lift both as long as they stay distinct; they end just below the LCA.
        for (int k = LOG - 1; k >= 0; --k) {
            if (up[k][a] != up[k][b]) {
                a = up[k][a];
                b = up[k][b];
            }
        }
        return up[0][a];
    }

    // Number of edges on the path between a and b.
    int dist(int a, int b) const {
        int c = lca(a, b);
        return depth[a] + depth[b] - 2 * depth[c];
    }
};

/* =========================================================
 * 7) Dijkstra
 *    - Single-source shortest path
 *    - Non-negative edge weights only
 * ========================================================= */
struct Dijkstra {
    int n;
    vector<vector<pair<int, int>>> g;

    Dijkstra(int n = 0) { init(n); }

    void init(int n_) {
        n = n_;
        g.assign(n, {});
    }

    // Set undirected = false for directed graph
    void addEdge(int u, int v, int w, bool undirected = true) {
        g[u].push_back({v, w});
        if (undirected) g[v].push_back({u, w});
    }

    vector<ll> run(int s) const {
        vector<ll> dist(n, INF64);
        priority_queue<pair<ll, int>, vector<pair<ll, int>>, greater<pair<ll, int>>> pq;
        dist[s] = 0;
        pq.push({0, s});

        while (!pq.empty()) {
            auto [d, u] = pq.top();
            pq.pop();
            if (d != dist[u]) continue;

            for (auto [v, w] : g[u]) {
                if (dist[v] > d + w) {
                    dist[v] = d + w;
                    pq.push({dist[v], v});
                }
            }
        }
        return dist;
    }
};

/* =========================================================
 * 8) 0-1 BFS
 *    - Single-source shortest paths when every edge weight is 0 or 1
 *    - A deque replaces Dijkstra's heap: 0-edges go to the front,
 *      1-edges to the back
 *    - Unreachable vertices keep distance INF
 * ========================================================= */
struct ZeroOneBFS {
    int n;
    vector<vector<pair<int, int>>> g;  // g[u] = {(v, w), ...}, w in {0, 1}

    ZeroOneBFS(int n = 0) { init(n); }

    void init(int n_) {
        n = n_;
        g.assign(n, {});
    }

    // Add edge u -> v (and v -> u unless undirected == false). w must be 0 or 1.
    void addEdge(int u, int v, int w, bool undirected = true) {
        g[u].emplace_back(v, w);
        if (undirected) g[v].emplace_back(u, w);
    }

    vector<int> run(int s) const {
        vector<int> dist(n, INF);
        dist[s] = 0;
        deque<int> dq{s};

        while (!dq.empty()) {
            const int u = dq.front();
            dq.pop_front();
            for (const auto& [v, w] : g[u]) {
                const int cand = dist[u] + w;
                if (cand >= dist[v]) continue;
                dist[v] = cand;
                if (w == 0) dq.push_front(v);
                else dq.push_back(v);
            }
        }
        return dist;
    }
};

/* =========================================================
 * 9) Sparse Table for Range Maximum Query
 *    - Static array only
 *    - Build: O(n log n); Query: O(1)
 * ========================================================= */
struct SparseTable {
    int n, K;
    vector<int> lg;          // lg[x] = floor(log2(x)) for 1 <= x <= n
    vector<vector<int>> st;  // st[k][i] = max of a[i .. i + 2^k - 1]

    SparseTable() {}

    SparseTable(const vector<int>& a) { init(a); }

    void init(const vector<int>& a) {
        n = (int)a.size();
        if (n == 0) return;

        K = __lg(n) + 1;
        lg.assign(n + 1, 0);
        for (int x = 2; x <= n; ++x) lg[x] = 1 + lg[x >> 1];

        st.assign(K, vector<int>(n));
        st[0] = a;
        for (int k = 1; k < K; ++k) {
            const int half = 1 << (k - 1);
            const int last = n - (1 << k);  // last start index whose 2^k window fits
            const vector<int>& prev = st[k - 1];
            vector<int>& cur = st[k];
            for (int i = 0; i <= last; ++i) cur[i] = max(prev[i], prev[i + half]);
        }
    }

    // Max on [l, r], 0-based inclusive: two overlapping power-of-two windows.
    int query(int l, int r) const {
        const int k = lg[r - l + 1];
        const int fromLeft = st[k][l];
        const int fromRight = st[k][r + 1 - (1 << k)];
        return max(fromLeft, fromRight);
    }
};

/* =========================================================
 * 10) KMP
 *     - Prefix function
 *     - Pattern matching in O(|text| + |pattern|)
 * ========================================================= */
// pi[i] = length of the longest proper prefix of s[0..i] that is also a suffix of it.
vector<int> prefix_function(const string& s) {
    const int len = (int)s.size();
    vector<int> pi(len, 0);
    for (int i = 1; i < len; ++i) {
        int k = pi[i - 1];
        // Fall back through shorter borders until s[i] can extend one.
        while (k != 0 && s[k] != s[i]) k = pi[k - 1];
        pi[i] = (s[k] == s[i]) ? k + 1 : k;
    }
    return pi;
}

// Return all starting positions where pattern appears in text (overlaps included).
// An empty pattern yields no positions.
vector<int> kmp_search(const string& text, const string& pattern) {
    vector<int> pos;
    const int m = (int)pattern.size();
    if (m == 0) return pos;

    const vector<int> pi = prefix_function(pattern);
    int matched = 0;  // length of the pattern prefix currently matched
    for (int i = 0; i < (int)text.size(); ++i) {
        const char c = text[i];
        while (matched > 0 && pattern[matched] != c) matched = pi[matched - 1];
        if (pattern[matched] == c) ++matched;
        if (matched < m) continue;
        pos.push_back(i + 1 - m);
        matched = pi[m - 1];
    }
    return pos;
}

/* =========================================================
 * 11) Z-Function
 *     - z[i] = length of the longest substring starting at i
 *       that is also a prefix of the string; z[0] = |s|
 *     - O(n) using the rightmost matched window [l, r]
 * ========================================================= */
vector<int> z_function(const string& s) {
    const int n = (int)s.size();
    vector<int> z(n, 0);
    if (n == 0) return z;
    z[0] = n;  // never read inside the loop, so it can be set up front

    for (int i = 1, l = 0, r = 0; i < n; ++i) {
        // Reuse what the window [l, r] already proves about s[i..r].
        int len = (i <= r) ? min(r - i + 1, z[i - l]) : 0;
        while (i + len < n && s[len] == s[i + len]) ++len;
        z[i] = len;
        if (i + len - 1 > r) {
            l = i;
            r = i + len - 1;
        }
    }
    return z;
}

/* =========================================================
 * 12) Trie
 *     - Lowercase English letters only: 'a' to 'z'
 *     - Supports insertion and prefix/word counting
 *     - insert() rejects other characters with std::invalid_argument;
 *       queries containing them match nothing (return 0)
 * ========================================================= */
struct Trie {
    struct Node {
        int nxt[26];  // child index per letter, -1 if absent
        int wordCnt;  // inserted strings that end exactly at this node
        int passCnt;  // inserted strings whose path reaches this node
        Node() {
            fill(nxt, nxt + 26, -1);
            wordCnt = 0;
            passCnt = 0;
        }
    };

    vector<Node> t;  // t[0] is the root (empty prefix)

    Trie() {
        t.push_back(Node());
    }

    // Insert one occurrence of s. Throws std::invalid_argument, before
    // modifying the trie, if s contains a character outside 'a'..'z'
    // (such a character would index nxt out of bounds).
    void insert(const string& s) {
        for (char ch : s) {
            if (letterIndex(ch) < 0) {
                throw invalid_argument("Trie::insert: only 'a'..'z' are allowed");
            }
        }
        int u = 0;
        t[u].passCnt++;
        for (char ch : s) {
            int c = letterIndex(ch);
            if (t[u].nxt[c] == -1) {
                t[u].nxt[c] = (int)t.size();
                t.push_back(Node());
            }
            u = t[u].nxt[c];
            t[u].passCnt++;
        }
        t[u].wordCnt++;
    }

    // Number of exact matches of s
    int countWord(const string& s) const {
        int u = findNode(s);
        return u < 0 ? 0 : t[u].wordCnt;
    }

    // Number of inserted strings with prefix s
    int countPrefix(const string& s) const {
        int u = findNode(s);
        return u < 0 ? 0 : t[u].passCnt;
    }

private:
    // Child slot for ch, or -1 if ch is not a lowercase ASCII letter.
    static int letterIndex(char ch) {
        return ('a' <= ch && ch <= 'z') ? ch - 'a' : -1;
    }

    // Node reached by following s from the root, or -1 if no such path exists.
    int findNode(const string& s) const {
        int u = 0;
        for (char ch : s) {
            int c = letterIndex(ch);
            if (c < 0 || t[u].nxt[c] == -1) return -1;
            u = t[u].nxt[c];
        }
        return u;
    }
};

/* =========================================================
 * 13) Topological Sort
 *     - Kahn's algorithm on vertices 0..n-1 (g[u] lists edges u -> v)
 *     - Among ready vertices, the one that became ready first is emitted first
 *     - Returns an empty vector if the graph has a cycle
 * ========================================================= */
vector<int> topo_sort(int n, const vector<vector<int>>& g) {
    vector<int> indeg(n, 0);
    for (int u = 0; u < n; ++u)
        for (int v : g[u]) ++indeg[v];

    queue<int> ready;  // vertices whose predecessors have all been emitted
    for (int v = 0; v < n; ++v)
        if (indeg[v] == 0) ready.push(v);

    vector<int> order;
    while (!ready.empty()) {
        const int u = ready.front();
        ready.pop();
        order.push_back(u);
        for (int v : g[u]) {
            --indeg[v];
            if (indeg[v] == 0) ready.push(v);
        }
    }

    // Vertices never emitted lie on, or downstream of, a cycle.
    if ((int)order.size() != n) order.clear();
    return order;
}

/* =========================================================
 * 14) Kruskal MST
 *     - Minimum spanning tree (forest, if disconnected) of an
 *       undirected weighted graph on vertices 0..n-1
 *     - Requires DSU
 *     - Returns {total weight, chosen edges}; if the graph is
 *       disconnected, fewer than n - 1 edges are chosen
 *     - Equal-weight edges are considered in input order, so the
 *       chosen edge set is deterministic
 * ========================================================= */
struct Edge {
    int u, v;
    ll w;
};

pair<ll, vector<Edge>> kruskal_mst(int n, vector<Edge> edges) {
    // std::sort leaves the order of equal weights unspecified, so the chosen
    // edges could differ between library implementations; stable_sort does not.
    stable_sort(edges.begin(), edges.end(), [](const Edge& a, const Edge& b) {
        return a.w < b.w;
    });

    DSU dsu(n);
    ll cost = 0;
    vector<Edge> used;

    for (const auto& e : edges) {
        // A spanning tree has exactly n - 1 edges; once they are chosen, every
        // remaining edge would join two already-connected vertices.
        if ((int)used.size() >= n - 1) break;
        if (dsu.unite(e.u, e.v)) {
            cost += e.w;
            used.push_back(e);
        }
    }

    return {cost, std::move(used)};
}

Usage examples

1) ModInt

// mint: integers modulo the prime 998244353, so division is available.
using mint = ModInt<998244353>;

mint a = 2, b = 3;
cout << a + b << '\n'; // 5
cout << a * b << '\n'; // 6
cout << a / b << '\n'; // 2 * 3^{-1} mod 998244353

2) Fenwick Tree

// n elements, all initially zero; indices are 0-based (n assumed defined by the caller).
Fenwick<long long> fw(n);
fw.add(2, 5); // a[2] += 5
cout << fw.rangeSum(1, 4); // sum of [1,4]

3) Iterative Segment Tree

vector<long long> a = {1, 3, 2, 7, 4};
// Associative combine function; 0 is its identity element.
auto merge = [](long long x, long long y) { return x + y; };

SegTree<long long, decltype(merge)> st(a, 0LL, merge);
cout << st.query(1, 3) << '\n'; // sum on [1,3]: 3 + 2 + 7 = 12

st.setVal(2, 10); // a[2] = 10
cout << st.query(0, 4) << '\n'; // 1 + 3 + 10 + 7 + 4 = 25

For a range-maximum query:

// Range maximum: the identity must be smaller than every element.
auto merge = [](long long x, long long y) { return max(x, y); };
SegTree<long long, decltype(merge)> st(a, -(1LL << 60), merge);
cout << st.query(1, 4) << '\n'; // max on [1,4]

4) Lazy Segment Tree

vector<long long> a = {1, 2, 3, 4, 5};
LazySegTree lst(a);

lst.rangeAdd(1, 3, 10); // add 10 to [1,3]
cout << lst.rangeSum(0, 4); // total sum: 15 + 3 * 10 = 45
cout << lst.rangeSum(2, 2); // single position query: 3 + 10 = 13

5) DSU

DSU dsu(n);
dsu.unite(u, v); // merge the sets of u and v; returns false if already merged
if (dsu.same(u, v)) cout << "Connected\n";

6) LCA

LCA lca(n);
lca.addEdge(u, v); // once per tree edge (undirected)
lca.build(0);      // root at vertex 0; call after all addEdge calls

cout << lca.lca(a, b) << '\n';
cout << lca.dist(a, b) << '\n'; // number of edges on the a-b path
cout << lca.jump(x, 3) << '\n'; // 3rd ancestor of x

7) Dijkstra

Dijkstra g(n);
g.addEdge(u, v, w); // undirected by default
auto dist = g.run(0); // distances from vertex 0; unreachable vertices stay INF64
cout << dist[t] << '\n';

For a directed graph:

g.addEdge(u, v, w, false); // directed edge u -> v only

8) 0-1 BFS

ZeroOneBFS g(n);
g.addEdge(u, v, 0); // weights must be 0 or 1
g.addEdge(x, y, 1);

auto dist = g.run(0); // unreachable vertices stay INF
cout << dist[t] << '\n';

9) Sparse Table

vector<int> a = {1, 5, 3, 8, 2};
SparseTable sp(a);

cout << sp.query(1, 3) << '\n'; // max on [1,3] = 8

10) KMP

string text = "ababcabcab";
string pattern = "abc";

vector<int> pos = kmp_search(text, pattern);
// pos contains all starting indices of matches: {2, 5}

11) Z-function

string s = "aaaaa";
vector<int> z = z_function(s); // {5, 4, 3, 2, 1}

12) Trie

Trie trie;
trie.insert("apple");
trie.insert("app");

cout << trie.countWord("app") << '\n'; // exact matches: 1
cout << trie.countPrefix("app") << '\n'; // words starting with "app": 2

13) Topological Sort

int n = 5;
vector<vector<int>> g(n);  // edge u -> v is stored in g[u]
g[0].push_back(1);
g[1].push_back(2);

vector<int> order = topo_sort(n, g);
if (!order.empty()) {
    for (int x : order) cout << x << ' ';  // prints: 0 3 4 1 2
    cout << '\n';
} else {
    cout << "Cycle detected\n";
}

14) Kruskal MST

vector<Edge> edges = {
    {0, 1, 4},
    {1, 2, 2},
    {0, 2, 5},
};

auto [cost, used] = kruskal_mst(3, edges);
cout << cost << '\n';  // 6: edge (1, 2, 2) then edge (0, 1, 4)

评论

Leave a Reply