树链剖分


对于树的快速算法.

pre

重儿子: 子树节点最多的儿子. 轻儿子: 非重儿子. 重边: 重儿子 \(to\) 父节点. 重链: 由多条重边连接而成的路径.

  • 整棵树被剖分成若干重链.
  • 轻儿子一定是每条重链的顶点.
  • 任意一条路径被切分为不超过 \(log{n}\) 条重链.

求 \(lca\)

定义数组 \(fa, dep, siz, son, top\) 分别表示父亲, 深度, 子树节点数, 重儿子, 所在重链顶点.

先跑一边 \(dfs\), 易得 \(fa, dep, siz, son\).

再跑一遍得到 \(top\).

求 \(lca\) 时, 若两点不在同一重链, 则将重链头节点更深的一个点跳到头节点的父节点, 重复这一过程.

最后得到两个点在用一条重链上, 那么显然 \(lca\) 就是更浅的那个点.

因为任意一条路径被切分为不超过 \(log{n}\) 条重链, 所以算法复杂度为 \(O(log{n})\).

P3379 【模板】最近公共祖先(LCA)

HLD

void solve() {
    int n, m, s;
    cin >> n >> m >> s;
    vector<vector<int>> gra(n + 1);
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        gra[u].push_back(v);
        gra[v].push_back(u);
    }
    HLD hld(gra, s);
    while (m--) {
        int u, v;
        cin >> u >> v;
        cout << hld.lca(u, v) << '\n';
    }
}

对路径操作

P3384 【模板】重链剖分/树链剖分

对于子树操作易用线段树区间加区间求和维护.

现在考虑操作1, 2.

很显然我们能将路径分为 \(log{n}\) 条链, 对这些链用线段树区间加区间求和即可.

LSGT

int Mod;

struct tag {
    long long add;
    tag() : add(0) {}
    tag(long long x) : add(x) {}
    bool empty() const {
        return !add;
    }
    void apply(const tag &o) {
        add += o.add;
        add %= Mod;
    }
};
struct info {
    int len;
    long long sum;
    info() : len(1), sum(0) {}
    info(long long x) : len(1), sum(x % Mod) {}
    info(int len, long long sum) : len(len), sum(sum) {}
    info operator+(const info &o) const {
        return info{(len + o.len) % Mod, (sum + o.sum) % Mod};
    }
    void apply(const tag &o) {
        sum += o.add * len;
        sum %= Mod;
    }
};

struct HLD {
    std::vector<int> fa, dep, siz, son, top, id, val;
    LSGT<info, tag> tr;
    // 1-index
    HLD(const std::vector<std::vector<int>> &gra, int root, const std::vector<int> &w)
        : fa(gra.size()), dep(gra.size()), siz(gra.size()), son(gra.size()), top(gra.size()), id(gra.size()), val(gra.size()), tr(1) {
        auto dfs1 =  // fa, dep, siz, son
            [&](int u, int pre, auto self) -> void {
            fa[u] = pre, dep[u] = dep[pre] + 1, siz[u] = 1;
            for (auto v : gra[u]) {
                if (v != pre) {
                    self(v, u, self);
                    siz[u] += siz[v];
                    if (siz[v] > siz[son[u]])
                        son[u] = v;
                }
            }
        };
        dfs1(root, 0, dfs1);
        int idx = 0;
        auto dfs2 =  // top, id, val
            [&](int u, int t, auto self) -> void {
            id[u] = ++idx, val[idx] = w[u];
            top[u] = t;
            if (!son[u])
                return;
            self(son[u], t, self);  // 搜重儿子
            for (auto v : gra[u]) {
                if (v != fa[u] && v != son[u])
                    self(v, v, self);  // 搜轻儿子
            }
        };
        dfs2(root, root, dfs2);
        tr = LSGT<info, tag>(val);
    }
    LL query(int u, int v) {
        LL ans = 0;
        while (top[u] != top[v]) {
            if (dep[top[u]] < dep[top[v]]) {
                swap(u, v);
            }
            ans += tr.query(id[top[u]], id[u]).sum;
            u = fa[top[u]];
        }
        auto [l, r] = minmax(id[u], id[v]);
        return (ans + tr.query(l, r).sum) % Mod;
    }
    LL query(int x) {
        return tr.query(id[x], id[x] + siz[x] - 1).sum;
    }
    void modify(int u, int v, int z) {
        z %= Mod;
        while (top[u] != top[v]) {
            if (dep[top[u]] < dep[top[v]]) {
                swap(u, v);
            }
            tr.modify(id[top[u]], id[u], z);
            u = fa[top[u]];
        }
        auto [l, r] = minmax(id[u], id[v]);
        tr.modify(l, r, z);
    }
    void modify(int x, int z) {
        tr.modify(id[x], id[x] + siz[x] - 1, z % Mod);
    }
};

void solve() {
    int n, m, r;
    cin >> n >> m >> r >> Mod;
    vector<int> w(n + 1);
    for (int i = 1; i <= n; i++) cin >> w[i];
    vector<vector<int>> gra(n + 1);
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        gra[u].push_back(v);
        gra[v].push_back(u);
    }
    HLD hld(gra, r, w);
    while (m--) {
        int op;
        cin >> op;
        if (op == 1) {
            int x, y, z;
            cin >> x >> y >> z;
            hld.modify(x, y, z);
        } else if (op == 2) {
            int x, y;
            cin >> x >> y;
            cout << hld.query(x, y) << '\n';
        } else if (op == 3) {
            int x, z;
            cin >> x >> z;
            hld.modify(x, z);
        } else {
            int x;
            cin >> x;
            cout << hld.query(x) << '\n';
        }
    }
}

求每个子树的信息

CF600E. Lomsat gelral

对于颜色, 我们维护一个 \(cnt\) 数组和 \(mx\), \(sum\) 变量用来统计当前的颜色出现次数/颜色出现最大次数/出现次数最大的颜色的和.

进行一个树链剖分, 计算答案的时候先遍历轻子树, 每计算完一个轻子树就清空, 然后计算重子树答案, 不清空, 最后再遍历轻子树, 将轻子树的颜色加入答案.

因为任意一条路径被切分为不超过 \(log{n}\) 条重链, 即每个点最多有 \(log{n}\) 个为轻儿子的祖先, 即每个点最多遍历到 \(log{n}\) 次.

复杂度为 \(O(nlog{n})\).

vector<int> col;
vector<LL> ans;
vector<vector<int>> gra;
vector<int> cnt;
int mx;
LL sum;

struct HLD {
    std::vector<int> fa, dep, siz, son, top;
    std::vector<std::vector<int>> &gra;
    // 1-index
    HLD(std::vector<std::vector<int>> &gra, int root)
        : fa(gra.size()), dep(gra.size()), siz(gra.size()), son(gra.size()), top(gra.size()), gra(gra) {
        auto dfs1 =  // fa, dep, siz, son
            [&](int u, int pre, auto self) -> void {
            fa[u] = pre, dep[u] = dep[pre] + 1, siz[u] = 1;
            for (auto v : gra[u]) {
                if (v != pre) {
                    self(v, u, self);
                    siz[u] += siz[v];
                    if (siz[v] > siz[son[u]])
                        son[u] = v;
                }
            }
        };
        dfs1(root, 0, dfs1);
        auto dfs2 =  // top
            [&](int u, int t, auto self) -> void {
            top[u] = t;
            for (auto v : gra[u]) {
                if (v != fa[u] && v != son[u])
                    self(v, v, self);  // 搜轻儿子
            }
            if (son[u])  // 搜重儿子
                self(son[u], t, self);
        };
        dfs2(root, root, dfs2);
    }
    int lca(int u, int v) const {
        while (top[u] != top[v]) {
            if (dep[top[u]] < dep[top[v]])
                v = fa[top[v]];
            else
                u = fa[top[u]];
        }

        return dep[u] < dep[v] ? u : v;
    }

    void clean(int u, int pre) {
        cnt[col[u]]--;
        for (auto v : gra[u]) {
            if (v != pre)
                clean(v, u);
        }
    }

    void add(int u, int pre, int son) {
        auto &times = cnt[col[u]];
        times++;
        if (times == mx)
            sum += col[u];
        else if (times > mx)
            mx = times, sum = col[u];
        for (auto v : gra[u])
            if (v != son && v != pre)
                add(v, u, son);
    }

    void work(int u, int pre, int isson) {
        for (auto v : gra[u]) {
            if (v != pre && v != son[u])
                work(v, u, 0);
        }
        if (son[u])
            work(son[u], u, 1);
        add(u, pre, son[u]);
        ans[u] = sum;
        if (!isson)
            clean(u, pre), mx = sum = 0;
    }
};

void solve() {
    int n;
    cin >> n;
    col.resize(n + 1);
    ans.resize(n + 1);
    gra.resize(n + 1);
    cnt.resize(n + 1);
    for (int i = 1; i <= n; i++)
        cin >> col[i];
    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        gra[u].push_back(v);
        gra[v].push_back(u);
    }
    HLD hld(gra, 1);
    hld.work(1, 0, 0);
    for (int i = 1; i <= n; i++) cout << ans[i] << ' ';
    cout << '\n';
}
,

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注