题目
数轴上n(n<=3e5)条线段,第i条线段[li,ri](0<=li<=ri<=3e5),
记第i条线段覆盖住的所有整点所在的集合为Si,定义以下三种运算:
1. A∪B为A和B集合的并集
2. A∩B为A和B集合的交集,
3. AB为A和B集合的对称差(即A∪B-A∩B)
现有一个长度为n-1的未确定符号的数组op,第i个符号可以是这三个符号中的一种,
求这种情况下,式子
的总和,
其中,|S|表示S集合内元素的个数,答案对998244353取模
思路来源
jiangly代码&Young_sean代码&自己乱搞的做法
题解1
对应代码1,是比较想补的动态dp做法,又称线段树维护矩阵优化dp,代码来自jiangly
感觉是个比较套路的做法,
考虑每个数的贡献,则每个集合只有包含和不包含两种情况,即这个值变成01,
操作也随之变成01之间的&(与)、|(或)、^(异或)
假设当前只有一条线段[l,r],其对应集合S,
1. 对于在[l,r]上的值x来说,之前经历了若干次操作(S1 op S2 op ...)后,
若当前式子中的值为1,则当其再op S时,因为x在S上,所以相当于op 1,
op 1时,1有&、|两种方式转移到1,有^一种方式转移到0(1 op 1)
若当前式子中的值为0,则当其再op S时,
op 1时,0有|、^两种方式转移到1,有&一种方式转移到0(0 op 1)
即有矩阵M1:
M1.a[0][0] = 1;
M1.a[0][1] = 2;
M1.a[1][0] = 1;
M1.a[1][1] = 2;
2.同理,若当前值x不在[l,r]上,之前经历了若干次操作(S1 op S2 op ...)后,
若当前式子中的值为1,则当其再op S时,因为x不在S上,所以相当于op 0,
op 0时,1有|、^两种方式转移到1,有&一种方式转移到0(1 op 0)
若当前式子中的值为0,则当其再op S时,
op 0时,0没有方式能转移到1,有|、&、^三种方式转移到0(0 op 0)
即有矩阵M0:
M0.a[0][0] = 3;
M0.a[0][1] = 0;
M0.a[1][0] = 1;
M0.a[1][1] = 2;
而对于数x来说,是否在S1对应的[l,r]相当于其初始状态的0/1值,
后续的S2-Sn,相当于通过若干道转移关卡(乘上若干个转移矩阵),
第i次根据x当前值以及x是否在Si上两个条件,决定如何转移
当增序从1遍历到3e5的时候,对于一个区间[l,r]来说,
在[l,r]时,只会乘矩阵M1;不在[l,r]上时,只会乘矩阵M0
矩阵变化,只会发生在进入l时以及离开r时,
遍历3e5个数后,均摊下来变化是2*n次,线段树维护矩阵即可
注意初始的01值,只能由x是否在[l1,r1]上决定,
不能钦定x=0/1后走转移矩阵,二者在答案上显然不同
题解2
组合数学,统计每个数x的贡献,仍然是考虑01序列
首先,对于每个点求覆盖其的最大区间的id,即统计其出现的最后位置
不妨,以下以一个长度为5的01序列表示,
1. 10000,即最后出现在首位置,因为0 op 0不会得出1,
所以,后续的操作都只能1 op 0,即每次只能选&、|两种操作,
种选法
2. 01000,即最后出现在第二位置,后三个符号同理只能选&、|
而第一个符号只能选让结果为1的那两个,0|1=1,0^1=1
注意到,11000同理,第一个符号还是只能选让结果为1的那两个,1|1=1,1&1=1
种选法
3. 00100,最后出现在[3,n]的位置x,[x,n]之间这n-x个符号,每个符号只能有&、|两种选法,
x和x-1之间的符号也只能有两种选法,视前面的值是0还是1,决定选|&还是|^
[1,x-1]之间的符号有三种选法,任选即可,无论终值是什么都可以被x和x-1之间的符号捞回来
对于位置x,有种选法,第二三两种情况可以统一
对于每个值x统计贡献,求和即可
题解3
自己的乱搞,实则是手玩发现的规律,后来找了找其合理性,
感觉其实和题解2本质相同,因为存在1在出现了之后需要一直保持着到结尾结束的情况,
所以统计这个1是在什么时候最后一次出现的,
即如果枚举第i位是最后一次出现的1时,强制给第i-1次操作后的值赋0
心得
其实后两种题解也涉及到一类经典问题
若干个区间覆盖,对于每个点求覆盖其的最大区间的id
若干个区间覆盖,求当前盖住的点的总个数(去重)
本质是同一个问题,允许离线,有很多种写法
1. 线段树维护区间或
2. 线段树维护区间覆盖,flag表示区间是否被完全覆盖,已经被盖住后续就忽略,均摊
3. 从小到大遍历x,l处加线段id,r处删线段id,set或优先队列维护当前最大的线段id
代码1(动态dp/线段树维护矩阵优化dp)
#include <bits/stdc++.h>
using i64 = long long;
constexpr int P = 998244353;
using i64 = long long;
// assume -P <= x < 2P
int norm(int x) {
if (x < 0) {
x += P;
}
if (x >= P) {
x -= P;
}
return x;
}
template<class T>
T power(T a, i64 b) {
T res = 1;
for (; b; b /= 2, a *= a) {
if (b % 2) {
res *= a;
}
}
return res;
}
struct Z {
int x;
Z(int x = 0) : x(norm(x)) {}
Z(i64 x) : x(norm(x % P)) {}
int val() const {
return x;
}
Z operator-() const {
return Z(norm(P - x));
}
Z inv() const {
assert(x != 0);
return power(*this, P - 2);
}
Z &operator*=(const Z &rhs) {
x = i64(x) * rhs.x % P;
return *this;
}
Z &operator+=(const Z &rhs) {
x = norm(x + rhs.x);
return *this;
}
Z &operator-=(const Z &rhs) {
x = norm(x - rhs.x);
return *this;
}
Z &operator/=(const Z &rhs) {
return *this *= rhs.inv();
}
friend Z operator*(const Z &lhs, const Z &rhs) {
Z res = lhs;
res *= rhs;
return res;
}
friend Z operator+(const Z &lhs, const Z &rhs) {
Z res = lhs;
res += rhs;
return res;
}
friend Z operator-(const Z &lhs, const Z &rhs) {
Z res = lhs;
res -= rhs;
return res;
}
friend Z operator/(const Z &lhs, const Z &rhs) {
Z res = lhs;
res /= rhs;
return res;
}
friend std::istream &operator>>(std::istream &is, Z &a) {
i64 v;
is >> v;
a = Z(v);
return is;
}
friend std::ostream &operator<<(std::ostream &os, const Z &a) {
return os << a.val();
}
};template<class Info,
class Merge = std::plus<Info>>
struct SegmentTree {
const int n;
const Merge merge;
std::vector<Info> info;
SegmentTree(int n) : n(n), merge(Merge()), info(4 << std::__lg(n)) {}
SegmentTree(std::vector<Info> init) : SegmentTree(init.size()) {
std::function<void(int, int, int)> build = [&](int p, int l, int r) {
if (r - l == 1) {
info[p] = init[l];
return;
}
int m = (l + r) / 2;
build(2 * p, l, m);
build(2 * p + 1, m, r);
pull(p);
};
build(1, 0, n);
}
void pull(int p) {
info[p] = merge(info[2 * p], info[2 * p + 1]);
}
void modify(int p, int l, int r, int x, const Info &v) {
if (r - l == 1) {
info[p] = v;
return;
}
int m = (l + r) / 2;
if (x < m) {
modify(2 * p, l, m, x, v);
} else {
modify(2 * p + 1, m, r, x, v);
}
pull(p);
}
void modify(int p, const Info &v) {
modify(1, 0, n, p, v);
}
Info rangeQuery(int p, int l, int r, int x, int y) {
if (l >= y || r <= x) {
return Info();
}
if (l >= x && r <= y) {
return info[p];
}
int m = (l + r) / 2;
return merge(rangeQuery(2 * p, l, m, x, y), rangeQuery(2 * p + 1, m, r, x, y));
}
Info rangeQuery(int l, int r) {
return rangeQuery(1, 0, n, l, r);
}
};
struct Matrix {
Z a[2][2];
};
Matrix operator+(const Matrix &a, const Matrix &b) {
Matrix c;
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
for (int k = 0; k < 2; k++) {
c.a[i][j] += a.a[i][k] * b.a[k][j];
}
}
}
return c;
}
Matrix M0, M1;
int main() {
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
M0.a[0][0] = 3;
M0.a[0][1] = 0;
M0.a[1][0] = 1;
M0.a[1][1] = 2;
M1.a[0][0] = 1;
M1.a[0][1] = 2;
M1.a[1][0] = 1;
M1.a[1][1] = 2;
int n;
std::cin >> n;
std::vector<int> l(n), r(n);
int max = 0;
for (int i = 0; i < n; i++) {
std::cin >> l[i] >> r[i];
max = std::max(max, r[i]);
}
max++;
std::vector<std::vector<int>> add(max), del(max);
for (int i = 0; i < n; i++) {
add[l[i]].push_back(i);
del[r[i]].push_back(i);
}
int a0 = 0;
SegmentTree<Matrix> seg(std::vector(n - 1, M0));
Z ans = 0;
for (int i = 0; i < max; i++) {
for (auto j : add[i]) {
if (j) {
seg.modify(j - 1, M1);
} else {
a0 = 1;
}
}
ans += seg.info[1].a[a0][1];
for (auto j : del[i]) {
if (j) {
seg.modify(j - 1, M0);
} else {
a0 = 0;
}
}
}
std::cout << ans << "\n";
return 0;
}
代码2(组合数学)
import random
import sys
import os
import math
from collections import Counter, defaultdict, deque
from functools import lru_cache, reduce
from itertools import accumulate, combinations, permutations
from heapq import nsmallest, nlargest, heapify, heappop, heappush
from io import BytesIO, IOBase
from copy import deepcopy
import threading
import bisect
BUFSIZE = 4096
class FastIO(IOBase):
newlines = 0
def __init__(self, file):
self._fd = file.fileno()
self.buffer = BytesIO()
self.writable = "x" in file.mode or "r" not in file.mode
self.write = self.buffer.write if self.writable else None
def read(self):
while True:
b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
if not b:
break
ptr = self.buffer.tell()
self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
self.newlines = 0
return self.buffer.read()
def readline(self):
while self.newlines == 0:
b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
self.newlines = b.count(b"\n") + (not b)
ptr = self.buffer.tell()
self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
self.newlines -= 1
return self.buffer.readline()
def flush(self):
if self.writable:
os.write(self._fd, self.buffer.getvalue())
self.buffer.truncate(0), self.buffer.seek(0)
class IOWrapper(IOBase):
def __init__(self, file):
self.buffer = FastIO(file)
self.flush = self.buffer.flush
self.writable = self.buffer.writable
self.write = lambda s: self.buffer.write(s.encode("ascii"))
self.read = lambda: self.buffer.read().decode("ascii")
self.readline = lambda: self.buffer.readline().decode("ascii")
sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout)
input = lambda: sys.stdin.readline().rstrip("\r\n")
class SortedList:
def __init__(self, iterable=[], _load=200):
"""Initialize sorted list instance."""
values = sorted(iterable)
self._len = _len = len(values)
self._load = _load
self._lists = _lists = [values[i:i + _load] for i in range(0, _len, _load)]
self._list_lens = [len(_list) for _list in _lists]
self._mins = [_list[0] for _list in _lists]
self._fen_tree = []
self._rebuild = True
def _fen_build(self):
"""Build a fenwick tree instance."""
self._fen_tree[:] = self._list_lens
_fen_tree = self._fen_tree
for i in range(len(_fen_tree)):
if i | i + 1 < len(_fen_tree):
_fen_tree[i | i + 1] += _fen_tree[i]
self._rebuild = False
def _fen_update(self, index, value):
"""Update `fen_tree[index] += value`."""
if not self._rebuild:
_fen_tree = self._fen_tree
while index < len(_fen_tree):
_fen_tree[index] += value
index |= index + 1
def _fen_query(self, end):
"""Return `sum(_fen_tree[:end])`."""
if self._rebuild:
self._fen_build()
_fen_tree = self._fen_tree
x = 0
while end:
x += _fen_tree[end - 1]
end &= end - 1
return x
def _fen_findkth(self, k):
"""Return a pair of (the largest `idx` such that `sum(_fen_tree[:idx]) <= k`, `k - sum(_fen_tree[:idx])`)."""
_list_lens = self._list_lens
if k < _list_lens[0]:
return 0, k
if k >= self._len - _list_lens[-1]:
return len(_list_lens) - 1, k + _list_lens[-1] - self._len
if self._rebuild:
self._fen_build()
_fen_tree = self._fen_tree
idx = -1
for d in reversed(range(len(_fen_tree).bit_length())):
right_idx = idx + (1 << d)
if right_idx < len(_fen_tree) and k >= _fen_tree[right_idx]:
idx = right_idx
k -= _fen_tree[idx]
return idx + 1, k
def _delete(self, pos, idx):
"""Delete value at the given `(pos, idx)`."""
_lists = self._lists
_mins = self._mins
_list_lens = self._list_lens
self._len -= 1
self._fen_update(pos, -1)
del _lists[pos][idx]
_list_lens[pos] -= 1
if _list_lens[pos]:
_mins[pos] = _lists[pos][0]
else:
del _lists[pos]
del _list_lens[pos]
del _mins[pos]
self._rebuild = True
def _loc_left(self, value):
"""Return an index pair that corresponds to the first position of `value` in the sorted list."""
if not self._len:
return 0, 0
_lists = self._lists
_mins = self._mins
lo, pos = -1, len(_lists) - 1
while lo + 1 < pos:
mi = (lo + pos) >> 1
if value <= _mins[mi]:
pos = mi
else:
lo = mi
if pos and value <= _lists[pos - 1][-1]:
pos -= 1
_list = _lists[pos]
lo, idx = -1, len(_list)
while lo + 1 < idx:
mi = (lo + idx) >> 1
if value <= _list[mi]:
idx = mi
else:
lo = mi
return pos, idx
def _loc_right(self, value):
"""Return an index pair that corresponds to the last position of `value` in the sorted list."""
if not self._len:
return 0, 0
_lists = self._lists
_mins = self._mins
pos, hi = 0, len(_lists)
while pos + 1 < hi:
mi = (pos + hi) >> 1
if value < _mins[mi]:
hi = mi
else:
pos = mi
_list = _lists[pos]
lo, idx = -1, len(_list)
while lo + 1 < idx:
mi = (lo + idx) >> 1
if value < _list[mi]:
idx = mi
else:
lo = mi
return pos, idx
def add(self, value):
"""Add `value` to sorted list."""
_load = self._load
_lists = self._lists
_mins = self._mins
_list_lens = self._list_lens
self._len += 1
if _lists:
pos, idx = self._loc_right(value)
self._fen_update(pos, 1)
_list = _lists[pos]
_list.insert(idx, value)
_list_lens[pos] += 1
_mins[pos] = _list[0]
if _load + _load < len(_list):
_lists.insert(pos + 1, _list[_load:])
_list_lens.insert(pos + 1, len(_list) - _load)
_mins.insert(pos + 1, _list[_load])
_list_lens[pos] = _load
del _list[_load:]
self._rebuild = True
else:
_lists.append([value])
_mins.append(value)
_list_lens.append(1)
self._rebuild = True
def discard(self, value):
"""Remove `value` from sorted list if it is a member."""
_lists = self._lists
if _lists:
pos, idx = self._loc_right(value)
if idx and _lists[pos][idx - 1] == value:
self._delete(pos, idx - 1)
def remove(self, value):
"""Remove `value` from sorted list; `value` must be a member."""
_len = self._len
self.discard(value)
if _len == self._len:
raise ValueError('{0!r} not in list'.format(value))
def pop(self, index=-1):
"""Remove and return value at `index` in sorted list."""
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
value = self._lists[pos][idx]
self._delete(pos, idx)
return value
def bisect_left(self, value):
"""Return the first index to insert `value` in the sorted list."""
pos, idx = self._loc_left(value)
return self._fen_query(pos) + idx
def bisect_right(self, value):
"""Return the last index to insert `value` in the sorted list."""
pos, idx = self._loc_right(value)
return self._fen_query(pos) + idx
def count(self, value):
"""Return number of occurrences of `value` in the sorted list."""
return self.bisect_right(value) - self.bisect_left(value)
def __len__(self):
"""Return the size of the sorted list."""
return self._len
def __getitem__(self, index):
"""Lookup value at `index` in sorted list."""
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
return self._lists[pos][idx]
def __delitem__(self, index):
"""Remove value at `index` from sorted list."""
pos, idx = self._fen_findkth(self._len + index if index < 0 else index)
self._delete(pos, idx)
def __contains__(self, value):
"""Return true if `value` is an element of the sorted list."""
_lists = self._lists
if _lists:
pos, idx = self._loc_left(value)
return idx < len(_lists[pos]) and _lists[pos][idx] == value
return False
def __iter__(self):
"""Return an iterator over the sorted list."""
return (value for _list in self._lists for value in _list)
def __reversed__(self):
"""Return a reverse iterator over the sorted list."""
return (value for _list in reversed(self._lists) for value in reversed(_list))
def __repr__(self):
"""Return string representation of sorted list."""
return 'SortedList({0})'.format(list(self))
def I():
return input()
def II():
return int(input())
def MI():
return map(int, input().split())
def LI():
return list(input().split())
def LII():
return list(map(int, input().split()))
def GMI():
return map(lambda x: int(x) - 1, input().split())
def LGMI():
return list(map(lambda x: int(x) - 1, input().split()))
n = II()
intervals = []
INT = 998244353
for _ in range(n):
intervals.append(LII())
nums = [n] * (300001)
tmp = SortedList(list(range(300001)))
for idx, (x, y) in enumerate(intervals[::-1]):
pos = tmp.bisect_left(x)
while pos < len(tmp) and tmp[pos] <= y:
nums[tmp[pos]] = idx
tmp.remove(tmp[pos])
ans = 0
for val in nums:
if val == n:
continue
if val > n-2:
res = pow(2, n-1, INT)
else:
res = pow(2, val + 1, INT) * pow(3, n - 1 - val - 1, INT)
ans = (ans + res) % INT
print(ans)
代码3(自己的乱搞)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int N=3e5+10,mod=998244353;
int n,l[N],r[N],ans;
struct node{
int cs,cov;
node(){
cs=0;cov=-1;
}
}tr[N*4];
void up(int p){
tr[p].cs=tr[p<<1].cs+tr[p<<1|1].cs;
}
void bld(int p,int l,int r){
tr[p].cs=tr[p].cov=0;
if(l==r){
return;
}
int mid=(l+r)>>1;
bld(p<<1,l,mid);
bld(p<<1|1,mid+1,r);
up(p);
}
void psd(int p,int l,int r){
if(tr[p].cov){
int mid=(l+r)/2;
tr[p<<1].cs=mid-l+1;
tr[p<<1|1].cs=r-mid;
tr[p<<1].cov=tr[p].cov;
tr[p<<1|1].cov=tr[p].cov;
tr[p].cov=0;
}
}
void upd(int p,int l,int r,int ql,int qr){
if(ql<=l&&r<=qr){
tr[p].cs=r-l+1;
tr[p].cov=1;
return;
}
psd(p,l,r);
int mid=(l+r)>>1;
if(ql<=mid)upd(p<<1,l,mid,ql,qr);
if(qr>mid)upd(p<<1|1,mid+1,r,ql,qr);
up(p);
}
int ask(int p,int l,int r,int ql,int qr) {
if(ql<=l&&r<=qr)return tr[p].cs;
psd(p,l,r);
int res=0,mid=(l+r)>>1;
if(ql<=mid)res+=ask(p<<1,l,mid,ql,qr);
if(qr>mid)res+=ask(p<<1|1,mid+1,r,ql,qr);
return res;
}
int modpow(int x,int n,int mod){
int res=1;
for(;n;n>>=1,x=1ll*x*x%mod){
if(n&1)res=1ll*res*x%mod;
}
return res;
}
int main(){
scanf("%d",&n);
int inv=modpow(3,mod-2,mod),three=1;
for(int i=1;i<=n;++i){
scanf("%d%d",&l[i],&r[i]);
if(i<=n-2)three=3ll*three%mod;
}
bld(1,0,3e5);
for(int i=n;i>=1;--i){
upd(1,0,3e5,l[i],r[i]);
if(i==2)continue;
//ans=(ans+1ll*three[max(0,i-3)]*two[n-i+1]%mod*ask(1,1,n,1,n)%mod)%mod;// /3*2
if(i>1)three=1ll*three*inv%mod;
three=2ll*three%mod;
//printf("ask:%d\n",ask(1,0,3e5,0,3e5));
ans=(ans+1ll*three*ask(1,0,3e5,0,3e5)%mod)%mod;
}
printf("%d\n",ans);
return 0;
}