每次操作的均摊时间复杂度为 \(O(\log^2 n)\)。
class LSGT {
public:
struct tag {
LL addmax, addmax_, add, add_;
tag() : addmax(0), addmax_(-INF), add(0), add_(-INF) {}
tag(LL add) : addmax(add), addmax_(add), add(add), add_(add) {}
tag(LL addmax, LL addmax_, LL add, LL add_) : addmax(addmax), addmax_(addmax_), add(add), add_(add_) {}
bool empty() const {
return addmax == 0 && addmax_ == -INF && add == 0 && add_ == -INF;
}
void apply(const tag &o) {
add_ = max(add_, add + o.add_);
addmax_ = max(addmax_, addmax + o.addmax_);
add += o.add;
addmax += o.addmax;
}
};
struct info {
LL cntM1, sum, M1, M2, len, b;
info() : cntM1(1), sum(0), M1(-INF), M2(-INF), len(1), b(0) {}
info(LL x) : cntM1(1), sum(x), M1(x), M2(-INF), len(1), b(x) {}
info(LL cntM1, LL sum, LL M1, LL M2, LL len, LL b) : cntM1(cntM1), sum(sum), M1(M1), M2(M2), len(len), b(b) {}
friend info operator+(const info &a, const info &b) {
info res;
res.len = a.len + b.len;
res.sum = a.sum + b.sum;
res.b = max(a.b, b.b);
if (a.M1 == b.M1) {
res.cntM1 = a.cntM1 + b.cntM1;
res.M1 = a.M1;
res.M2 = max(a.M2, b.M2);
} else if (a.M1 > b.M1) {
res.cntM1 = a.cntM1;
res.M1 = a.M1;
res.M2 = max(a.M2, b.M1);
} else {
res.cntM1 = b.cntM1;
res.M1 = b.M1;
res.M2 = max(b.M2, a.M1);
}
return res;
}
void apply(const tag &o) {
sum += o.add * len - o.add * cntM1 + o.addmax * cntM1;
b = max(b, M1 + o.addmax_);
M1 += o.addmax;
M2 += o.add;
}
};
private:
std::vector<info> node;
std::vector<tag> ta;
int siz;
void build(int idx, int l, int r) {
if (l == r)
return;
int mid = (l + r) >> 1;
build(idx << 1, l, mid), build(idx << 1 | 1, mid + 1, r);
node[idx] = node[idx << 1] + node[idx << 1 | 1];
}
template <typename T>
void build(int idx, int l, int r, const std::vector<T> &vec) {
if (l == r) {
node[idx] = vec[l];
return;
}
int mid = (l + r) >> 1;
build(idx << 1, l, mid, vec), build(idx << 1 | 1, mid + 1, r, vec);
node[idx] = node[idx << 1] + node[idx << 1 | 1];
}
void apply(int idx) {
if (ta[idx].empty())
return;
LL maxx = max(node[idx << 1].M1, node[idx << 1 | 1].M1);
if (node[idx << 1].M1 == maxx)
node[idx << 1].apply(ta[idx]), ta[idx << 1].apply(ta[idx]);
else
node[idx << 1].apply({ta[idx].add, ta[idx].add_, ta[idx].add, ta[idx].add_}), ta[idx << 1].apply({ta[idx].add, ta[idx].add_, ta[idx].add, ta[idx].add_});
if (node[idx << 1 | 1].M1 == maxx)
node[idx << 1 | 1].apply(ta[idx]), ta[idx << 1 | 1].apply(ta[idx]);
else
node[idx << 1 | 1].apply({ta[idx].add, ta[idx].add_, ta[idx].add, ta[idx].add_}), ta[idx << 1 | 1].apply({ta[idx].add, ta[idx].add_, ta[idx].add, ta[idx].add_});
ta[idx] = {};
}
void modify(int idx, int l, int r, int ql, int qr, tag add) {
if (add.add == 0 && add.addmax >= node[idx].M1)
return;
if (ql <= l && qr >= r && (add.add != 0 || add.addmax > node[idx].M2)) {
if (add.add == 0) {
add.addmax = -node[idx].M1 + add.addmax;
add.addmax_ = add.addmax;
}
ta[idx].apply(add);
node[idx].apply(add);
return;
}
apply(idx);
int mid = (l + r) >> 1;
if (ql <= mid)
modify(idx << 1, l, mid, ql, qr, add);
if (qr > mid)
modify(idx << 1 | 1, mid + 1, r, ql, qr, add);
node[idx] = node[idx << 1] + node[idx << 1 | 1];
}
info query(int idx, int l, int r, int ql, int qr) {
if (ql <= l && qr >= r)
return node[idx];
apply(idx);
int mid = (l + r) >> 1;
if (qr <= mid)
return query(idx << 1, l, mid, ql, qr);
else if (ql > mid)
return query(idx << 1 | 1, mid + 1, r, ql, qr);
else
return query(idx << 1, l, mid, ql, qr) + query(idx << 1 | 1, mid + 1, r, ql, qr);
}
public:
LSGT(const int size) : node(size << 2), ta(size << 2), siz(size) {
build(1, 1, siz);
}
template <typename T>
LSGT(const std::vector<T> &vec) : node(vec.size() << 2), ta(vec.size() << 2), siz(vec.size() - 1) {
build(1, 1, siz, vec);
}
void modify(int ql, int qr, const tag &add) {
modify(1, 1, siz, ql, qr, add);
}
info query(int ql, int qr) {
return query(1, 1, siz, ql, qr);
}
};
void solve() {
int n, m;
cin >> n;
cin >> m;
vector<int> arr(n + 1);
for (int i = 1; i <= n; i++) cin >> arr[i];
LSGT tr(arr);
while (m--) {
int op;
cin >> op;
int l, r;
cin >> l >> r;
if (op == 1) {
int x;
cin >> x;
// 区间加, 注意为 0 不要加
if (x)
tr.modify(l, r, x);
} else if (op == 2) {
int x;
cin >> x;
// 区间取 min(x, a[i])
tr.modify(l, r, {x, 0, 0, 0});
} else if (op == 3) {
// 区间求和
cout << tr.query(l, r).sum << '\n';
} else if (op == 4) {
// 区间查最大值
cout << tr.query(l, r).M1 << '\n';
} else if (op == 5) {
// 区间查历史最大值
cout << tr.query(l, r).b << '\n';
}
}
}