HLD（Heavy-Light Decomposition，树链剖分）


本文包含两份模板：

1. 求 LCA（最近公共祖先）；
2. 树上路径加、子树加（配合线段树，同时支持路径求和、子树求和）。

## 求 LCA（最近公共祖先）

// Heavy-light decomposition of a rooted tree, used here to answer LCA queries.
// Vertices are 1-indexed (1 .. gra.size()-1). Index 0 is a sentinel meaning "no vertex":
// fa[root] == 0, son[leaf] == 0, and siz[0] == 0 makes any real child beat "no heavy child yet".
struct HLD {
    std::vector<int> fa, dep, siz, son, top;  // parent, depth (root = 1), subtree size, heavy child, chain head
    std::vector<std::vector<int>> &gra;       // adjacency list of the tree (referenced, not copied)
    // 1-indexed
    HLD(std::vector<std::vector<int>> &gra, int root)
        : fa(gra.size()), dep(gra.size()), siz(gra.size()), son(gra.size()), top(gra.size()), gra(gra) {
        // All passes are iterative: a recursive DFS uses one call frame per tree level and
        // overflows the stack on deep trees (e.g. a path of 1e5+ vertices).

        // Pass 1: preorder walk from the root filling fa and dep.
        // Every vertex appears in `order` after its parent.
        std::vector<int> order;
        order.reserve(gra.size());
        std::vector<int> stk{root};
        dep[root] = 1;
        while (!stk.empty()) {
            const int u = stk.back();
            stk.pop_back();
            order.push_back(u);
            siz[u] = 1;
            for (const int v : gra[u]) {
                if (v != fa[u]) {
                    fa[v] = u;
                    dep[v] = dep[u] + 1;
                    stk.push_back(v);
                }
            }
        }
        // Pass 2: reverse preorder visits every vertex after all of its descendants,
        // so subtree sizes can be accumulated bottom-up.
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            if (*it != root)
                siz[fa[*it]] += siz[*it];
        }
        // Pass 3: heavy child = first child (in adjacency order) with the strictly largest subtree.
        for (const int u : order) {
            for (const int v : gra[u]) {
                if (v != fa[u] && siz[v] > siz[son[u]])
                    son[u] = v;
            }
        }
        // Pass 4: a heavy child continues its parent's chain; the root and every light child start a new one.
        // Preorder guarantees top[fa[u]] is already known.
        for (const int u : order)
            top[u] = (u != root && son[fa[u]] == u) ? top[fa[u]] : u;
    }
    // Lowest common ancestor of u and v. Repeatedly lift the endpoint whose chain head is deeper
    // to the parent of that head. Once both share a chain, the shallower vertex is the answer.
    int lca(int u, int v) const {
        while (top[u] != top[v]) {
            if (dep[top[u]] < dep[top[v]])
                v = fa[top[v]];
            else
                u = fa[top[u]];
        }

        return dep[u] < dep[v] ? u : v;
    }
};

## 树上路径加、子树加、路径求和、子树求和（线段树维护，结果对 Mod 取模）

// Lazy-propagation segment tree over positions 1..siz.
//
// Requirements on the template parameters, as used below:
//   info: default-constructible (LSGT(int) uses info() for its leaves, and an empty query returns info());
//         assignable from the element type T of the vector constructor;
//         `info operator+(const info &) const` merges two adjacent ranges;
//         `void apply(const tag &)` applies a pending update to a whole node.
//   tag:  a default-constructed tag means "no pending update"; `bool empty() const` detects it;
//         `void apply(const tag &)` composes a newer update onto a pending one.
template <class info, class tag>
class LSGT {
    std::vector<info> node;  // node[idx]: aggregate of idx's range, already including ta[idx]
    std::vector<tag> ta;     // ta[idx]: update applied to node[idx] but not yet pushed to its children
    int siz;                 // number of positions; valid indices are 1..siz

    // Recompute node[idx] from its two children.
    void pull(int idx) {
        node[idx] = node[idx << 1] + node[idx << 1 | 1];
    }
    // Build internal nodes over [l, r]; leaves keep their default-constructed info.
    void build(int idx, int l, int r) {
        if (l == r)
            return;
        const int mid = (l + r) >> 1;
        build(idx << 1, l, mid);
        build(idx << 1 | 1, mid + 1, r);
        pull(idx);
    }
    // Build over [l, r] with leaf i initialised from vec[i] (vec is 1-indexed; vec[0] is ignored).
    template <typename T>
    void build(int idx, int l, int r, const std::vector<T> &vec) {
        if (l == r) {
            node[idx] = vec[l];
            return;
        }
        const int mid = (l + r) >> 1;
        build(idx << 1, l, mid, vec);
        build(idx << 1 | 1, mid + 1, r, vec);
        pull(idx);
    }
    // Push idx's pending tag down to both children, then reset it to "no update".
    void push_down(int idx) {
        if (ta[idx].empty())
            return;
        ta[idx << 1].apply(ta[idx]);
        ta[idx << 1 | 1].apply(ta[idx]);
        node[idx << 1].apply(ta[idx]);
        node[idx << 1 | 1].apply(ta[idx]);
        ta[idx] = {};
    }
    // Precondition: 1 <= ql <= qr <= siz (enforced by the public overload).
    void modify(int idx, int l, int r, int ql, int qr, const tag &add) {
        if (ql <= l && r <= qr) {
            ta[idx].apply(add);
            node[idx].apply(add);
            return;
        }
        push_down(idx);
        const int mid = (l + r) >> 1;
        if (ql <= mid)
            modify(idx << 1, l, mid, ql, qr, add);
        if (qr > mid)
            modify(idx << 1 | 1, mid + 1, r, ql, qr, add);
        pull(idx);
    }
    // Precondition: 1 <= ql <= qr <= siz (enforced by the public overload).
    info query(int idx, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr)
            return node[idx];
        push_down(idx);
        const int mid = (l + r) >> 1;
        if (qr <= mid)
            return query(idx << 1, l, mid, ql, qr);
        if (ql > mid)
            return query(idx << 1 | 1, mid + 1, r, ql, qr);
        return query(idx << 1, l, mid, ql, qr) + query(idx << 1 | 1, mid + 1, r, ql, qr);
    }

public:
    // Tree of `size` positions, each holding a default-constructed info.
    LSGT(const int size) : node(size << 2), ta(size << 2), siz(size) {
        if (siz >= 1)  // building an empty range [1, 0] would recurse forever
            build(1, 1, siz);
    }
    // Tree over vec[1..vec.size()-1]; vec[0] is unused (1-indexed input).
    template <typename T>
    LSGT(const std::vector<T> &vec)
        : node(vec.size() << 2), ta(vec.size() << 2), siz(static_cast<int>(vec.size()) - 1) {
        if (siz >= 1)  // a vector with no real elements would otherwise recurse forever
            build(1, 1, siz, vec);
    }
    // Apply `add` to every position in [ql, qr]. Bounds are clipped to [1, siz];
    // an empty range is a no-op (previously it recursed without bound).
    void modify(int ql, int qr, const tag &add) {
        if (ql < 1)
            ql = 1;
        if (qr > siz)
            qr = siz;
        if (ql > qr)
            return;
        modify(1, 1, siz, ql, qr, add);
    }
    // Merged info of [ql, qr], with bounds clipped to [1, siz].
    // An empty range returns a default-constructed info; callers that need a true identity
    // element for their info type must check for empty ranges themselves.
    info query(int ql, int qr) {
        if (ql < 1)
            ql = 1;
        if (qr > siz)
            qr = siz;
        if (ql > qr)
            return info();
        return query(1, 1, siz, ql, qr);
    }
};
// Lazy "add" update for the sum segment tree: every position in the tagged range gets `add` added.
// The stored delta is kept reduced modulo `Mod`. `Mod` is not defined in this snippet;
// the surrounding code must provide it.
struct tag {
    long long add;
    tag() : add(0) {}
    // Reduce on construction, like info(long long). Otherwise an unreduced delta handed directly to
    // LSGT::modify reaches info::apply and `add * len` can overflow long long.
    tag(long long x) : add(x % Mod) {}
    // A zero delta is a no-op; LSGT skips pushing it down.
    bool empty() const {
        return !add;
    }
    // Compose a later update `o` onto this pending one.
    void apply(const tag &o) {
        add = (add + o.add) % Mod;
    }
};
// Segment-tree node payload: the number of covered positions (`len`) and their sum modulo `Mod`.
// `Mod` is not defined in this snippet; the surrounding code must provide it.
struct info {
    int len;
    long long sum;
    // A default-constructed info is one position holding 0; LSGT(int) relies on this for its leaves.
    info() : len(1), sum(0) {}
    // A single leaf holding value x.
    info(long long x) : len(1), sum(x % Mod) {}
    info(int len, long long sum) : len(len), sum(sum) {}
    // Merge two adjacent ranges.
    // len is an element count and is NOT reduced modulo Mod. Reducing it was meaningless.
    // With brace-initialisation it also became an ill-formed narrowing conversion to `int`
    // whenever Mod has type long long.
    info operator+(const info &o) const {
        return info(len + o.len, (sum + o.sum) % Mod);
    }
    // Add o.add to each of the len positions.
    void apply(const tag &o) {
        sum = (sum + o.add * len) % Mod;
    }
};

struct HLD {
    std::vector<int> fa, dep, siz, son, top, id, val;
    LSGT<info, tag> tr;
    // 1-index
    HLD(const std::vector<std::vector<int>> &gra, int root, const std::vector<int> &w)
        : fa(gra.size()), dep(gra.size()), siz(gra.size()), son(gra.size()), top(gra.size()), id(gra.size()), val(gra.size()), tr(1) {
        auto dfs1 =  // fa, dep, siz, son
            [&](int u, int pre, auto self) -> void {
            fa[u] = pre, dep[u] = dep[pre] + 1, siz[u] = 1;
            for (auto v : gra[u]) {
                if (v != pre) {
                    self(v, u, self);
                    siz[u] += siz[v];
                    if (siz[v] > siz[son[u]])
                        son[u] = v;
                }
            }
        };
        dfs1(root, 0, dfs1);
        int idx = 0;
        auto dfs2 =  // top, id, val
            [&](int u, int t, auto self) -> void {
            id[u] = ++idx, val[idx] = w[u];
            top[u] = t;
            if (!son[u])
                return;
            self(son[u], t, self);  // 搜重儿子
            for (auto v : gra[u]) {
                if (v != fa[u] && v != son[u])
                    self(v, v, self);  // 搜轻儿子
            }
        };
        dfs2(root, root, dfs2);
        tr = LSGT<info, tag>(val);
    }
    LL query(int u, int v) {
        LL ans = 0;
        while (top[u] != top[v]) {
            if (dep[top[u]] < dep[top[v]]) {
                swap(u, v);
            }
            ans += tr.query(id[top[u]], id[u]).sum;
            u = fa[top[u]];
        }
        auto [l, r] = minmax(id[u], id[v]);
        return (ans + tr.query(l, r).sum) % Mod;
    }
    LL query(int x) {
        return tr.query(id[x], id[x] + siz[x] - 1).sum;
    }
    void modify(int u, int v, int z) {
        z %= Mod;
        while (top[u] != top[v]) {
            if (dep[top[u]] < dep[top[v]]) {
                swap(u, v);
            }
            tr.modify(id[top[u]], id[u], z);
            u = fa[top[u]];
        }
        auto [l, r] = minmax(id[u], id[v]);
        tr.modify(l, r, z);
    }
    void modify(int x, int z) {
        tr.modify(id[x], id[x] + siz[x] - 1, z % Mod);
    }
};

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注