支持区间对 x 取 min / 取 max（即限制区间的最大值、最小值）、区间加、区间赋值；查询区间最小值、最大值与区间和。


均摊每次操作 \(O(\log^2 n)\)。

class LSGT {
public:
    struct tag {
        LL add, Min, Max;
        tag() : add(0), Min(INF), Max(-INF) {}
        tag(LL add) : add(add), Min(INF), Max(-INF) {}
        tag(LL add, LL Min, LL Max) : add(add), Min(Min), Max(Max) {}
        bool empty() const {
            return add == 0 && Min == INF && Max == -INF;
        }
        void apply(const tag &o) {
            add += o.add;
            if (Min != INF)
                Min += o.add;
            if (Max != -INF)
                Max += o.add;
            if (Min <= o.Max) {
                Min = Max = o.Max;
            } else if (Max >= o.Min) {
                Min = Max = o.Min;
            } else {
                Min = min(o.Min, Min);
                Max = max(Max, o.Max);
            }
        }
    };
    struct info {
        LL cntM1, sum, M1, M2, len, cntm1, m1, m2;
        info() : cntM1(1), sum(0), M1(0), M2(-INF), len(1), cntm1(1), m1(0), m2(INF) {}
        info(LL x) : cntM1(1), sum(x), M1(x), M2(-INF), len(1), cntm1(1), m1(x), m2(INF) {}
        info(LL cntM1, LL sum, LL M1, LL M2, LL len, LL cntm1, LL m1, LL m2) : cntM1(cntM1), sum(sum), M1(M1), M2(M2), len(len), cntm1(cntm1), m1(m1), m2(m2) {}
        friend info operator+(const info &a, const info &b) {
            info res;
            res.len = a.len + b.len;
            res.sum = a.sum + b.sum;
            if (a.M1 == b.M1) {
                res.cntM1 = a.cntM1 + b.cntM1;
                res.M1 = a.M1;
                res.M2 = max(a.M2, b.M2);
            } else if (a.M1 > b.M1) {
                res.cntM1 = a.cntM1;
                res.M1 = a.M1;
                res.M2 = max(a.M2, b.M1);
            } else {
                res.cntM1 = b.cntM1;
                res.M1 = b.M1;
                res.M2 = max(b.M2, a.M1);
            }
            if (a.m1 == b.m1) {
                res.cntm1 = a.cntm1 + b.cntm1;
                res.m1 = a.m1;
                res.m2 = min(a.m2, b.m2);
            } else if (a.m1 < b.m1) {
                res.cntm1 = a.cntm1;
                res.m1 = a.m1;
                res.m2 = min(a.m2, b.m1);
            } else {
                res.cntm1 = b.cntm1;
                res.m1 = b.m1;
                res.m2 = min(a.m1, b.m2);
            }
            return res;
        }
        void apply(const tag &o) {
            sum += o.add * len;
            M1 += o.add;
            M2 += o.add;
            m1 += o.add;
            m2 += o.add;
            if (M1 == m1) {
                if (M1 > o.Min) {
                    sum -= (M1 - o.Min) * cntM1;
                    m1 = M1 = o.Min;
                }
                if (m1 < o.Max) {
                    sum -= (m1 - o.Max) * cntm1;
                    M1 = m1 = o.Max;
                }
                return;
            }
            if (o.Min == o.Max) {
                sum = len * o.Min;
                M1 = m1 = o.Min;
                cntm1 = cntM1 = len;
                return;
            }
            if (M1 > o.Min) {
                sum -= (M1 - o.Min) * cntM1;
                if (M1 == m2)
                    m2 = M1 = o.Min;
                else
                    M1 = o.Min;
            }
            if (m1 < o.Max) {
                sum -= (m1 - o.Max) * cntm1;
                if (m1 == M2)
                    m1 = M2 = o.Max;
                else
                    m1 = o.Max;
            }
        }
    };

private:
    std::vector<info> node;
    std::vector<tag> ta;
    int siz;
    void build(int idx, int l, int r) {
        if (l == r)
            return;
        int mid = (l + r) >> 1;
        build(idx << 1, l, mid), build(idx << 1 | 1, mid + 1, r);
        node[idx] = node[idx << 1] + node[idx << 1 | 1];
    }
    template <typename T>
    void build(int idx, int l, int r, const std::vector<T> &vec) {
        if (l == r) {
            node[idx] = vec[l];
            return;
        }
        int mid = (l + r) >> 1;
        build(idx << 1, l, mid, vec), build(idx << 1 | 1, mid + 1, r, vec);
        node[idx] = node[idx << 1] + node[idx << 1 | 1];
    }
    void apply(int idx) {
        if (ta[idx].empty())
            return;
        ta[idx << 1].apply(ta[idx]);
        ta[idx << 1 | 1].apply(ta[idx]);
        node[idx << 1].apply(ta[idx]);
        node[idx << 1 | 1].apply(ta[idx]);
        ta[idx] = {};
    }
    void modify(int idx, int l, int r, int ql, int qr, const tag &add) {
        if (!add.add && add.Min >= node[idx].M1 && add.Max <= node[idx].m1)
            return;
        if (ql <= l && qr >= r && add.Min > node[idx].M2 && add.Max < node[idx].m2) {
            ta[idx].apply(add);
            node[idx].apply(add);
            return;
        }
        apply(idx);
        int mid = (l + r) >> 1;
        if (ql <= mid)
            modify(idx << 1, l, mid, ql, qr, add);
        if (qr > mid)
            modify(idx << 1 | 1, mid + 1, r, ql, qr, add);
        node[idx] = node[idx << 1] + node[idx << 1 | 1];
    }
    info query(int idx, int l, int r, int ql, int qr) {
        if (ql <= l && qr >= r)
            return node[idx];
        apply(idx);
        int mid = (l + r) >> 1;
        if (qr <= mid)
            return query(idx << 1, l, mid, ql, qr);
        else if (ql > mid)
            return query(idx << 1 | 1, mid + 1, r, ql, qr);
        else
            return query(idx << 1, l, mid, ql, qr) + query(idx << 1 | 1, mid + 1, r, ql, qr);
    }

public:
    LSGT(const int size) : node(size << 2), ta(size << 2), siz(size) {
        build(1, 1, siz);
    }
    template <typename T>
    LSGT(const std::vector<T> &vec) : node(vec.size() << 2), ta(vec.size() << 2), siz(vec.size() - 1) {
        build(1, 1, siz, vec);
    }
    void modify(int ql, int qr, const tag &add) {
        modify(1, 1, siz, ql, qr, add);
    }
    info query(int ql, int qr) {
        return query(1, 1, siz, ql, qr);
    }
};

应用示例：

void solve() {
    int n, m;
    cin >> n;
    vector<int> arr(n + 1);
    for (int i = 1; i <= n; i++) cin >> arr[i];
    LSGT tr(arr);
    cin >> m;
    while (m--) {
        int op;
        cin >> op;
        int l, r;
        cin >> l >> r;
        if (op == 1) {
            int x;
            cin >> x;
            // 区间加 x
            tr.modify(l, r, {x});
        } else if (op == 2) {
            int x;
            cin >> x;
            // 区间取max(x,a[i])
            tr.modify(l, r, {0, INF, x});
        } else if (op == 3) {
            int x;
            cin >> x;
            // 区间取min(x,a[i])
            tr.modify(l, r, {0, x, -INF});
        } else if (op == 4) {
            // 区间求和
            cout << tr.query(l, r).sum << '\n';
        } else if (op == 5) {
            // 区间最大值
            cout << tr.query(l, r).M1 << '\n';
        } else if (op == 6) {
            // 区间最小值
            cout << tr.query(l, r).m1 << '\n';
        }
    }
}

// Program entry.
// With LOCAL defined: reads input from in.txt and reports elapsed time on stderr.
// Otherwise: unsynchronised fast stdio.
// Runs solve() once; uncomment the read to handle several test cases.
signed main() {
#ifdef LOCAL
    clock_t startTick = clock();
    freopen("in.txt", "r", stdin);
#else
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
#endif
    int testCases = 1;
    // cin >> testCases;
    while (testCases--) {
        solve();
    }
#ifdef LOCAL
    double elapsedMs = (clock() - startTick) / (CLOCKS_PER_SEC / 1000.0);
    cerr << "Time Used: " << fixed << setprecision(3) << elapsedMs << " ms" << endl;
#endif
    return 0;
}
,

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注