原题传送门
就是按顺序先找一段全是1的子序列,再找一段全是2的子序列,再找一段全是1的子序列,再找一段全是2的子序列
这样我们就找出了一段诸如
111122221111222
111122221111222
111122221111222的子序列,把第二段和第三段翻转,得到
111...111222...222
111...111222...222
111...111222...222,就是答案
就是找出长度最长的四段和
令
s
u
m
1
i
sum1_i
sum1i表示前缀1的个数
s
u
m
2
i
sum2_i
sum2i表示后缀2的个数
枚举第一段的结尾
i
i
i,第二段的结尾
j
j
j,第三段的结尾
k
k
k,第四段的结尾当然是
n
n
n
然后第一段的开头是
1
1
1,第二段的开头是
i
+
1
i+1
i+1,第三段的开头是
j
+
1
j+1
j+1,第四段的开头是
k
+
1
k+1
k+1
而且
0
<
=
i
<
=
j
<
=
k
<
=
n
0<=i<=j<=k<=n
0<=i<=j<=k<=n
对于一组
(
i
,
j
,
k
)
(i,j,k)
(i,j,k)
答案是
s
u
m
1
i
+
s
u
m
1
k
−
s
u
m
1
j
+
s
u
m
2
k
+
1
+
s
u
m
2
i
+
1
−
s
u
m
2
j
+
1
=
(
s
u
m
1
i
+
s
u
m
2
i
+
1
)
+
(
s
u
m
1
k
+
s
u
m
2
k
+
1
−
(
s
u
m
1
j
+
s
u
m
2
j
+
1
)
)
sum1_i+sum1_k-sum1_j+sum2_{k+1}+sum2_{i+1}-sum2_{j+1}=(sum1_i+sum2_{i+1})+(sum1_k+sum2_{k+1}-(sum1_j+sum2_{j+1}))
sum1i+sum1k−sum1j+sum2k+1+sum2i+1−sum2j+1=(sum1i+sum2i+1)+(sum1k+sum2k+1−(sum1j+sum2j+1))
所以可以维护前缀后缀
s
u
m
1
x
+
s
u
m
2
x
+
1
sum1_x+sum2_{x+1}
sum1x+sum2x+1的最大值
枚举
j
j
j
时间复杂度 O ( n l o g n ) O(nlogn) O(nlogn)
Code:
#include <bits/stdc++.h>
#define maxn 2010
using namespace std;
int n, a[maxn], sum1[maxn], sum2[maxn], tree1[maxn], tree2[maxn];
inline int read(){
int s = 0, w = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
return s * w;
}
int lowbit(int x){ return x & -x; }
void update1(int x, int y){ for (; x <= n; x += lowbit(x)) tree1[x] = max(tree1[x], y); }
void update2(int x, int y){ for (; x; x -= lowbit(x)) tree2[x] = max(tree2[x], y); }
int query1(int x){ int s = 0; for (; x; x -= lowbit(x)) s = max(s, tree1[x]); return s; }
int query2(int x){ int s = 0; for (; x <= n; x += lowbit(x)) s = max(s, tree2[x]); return s; }
int main(){
n = read();
for (int i = 1; i <= n; ++i) a[i] = read();
for (int i = 1; i <= n; ++i) sum1[i] = sum1[i - 1] + (a[i] == 1);
for (int i = n; i; --i) sum2[i] = sum2[i + 1] + (a[i] == 2);
for (int i = 1; i <= n; ++i) update1(i, sum1[i] + sum2[i + 1]), update2(i, sum1[i] + sum2[i + 1]);
update1(1, sum2[1]);
int ans = 0;
for (int i = 1; i <= n; ++i) ans = max(ans, query1(i) + query2(i) - sum1[i] - sum2[i + 1]);
printf("%d\n", ans);
return 0;
}
但是还有一种更优越的做法
令
d
p
i
,
1
/
2
/
3
/
4
dp_{i,1/2/3/4}
dpi,1/2/3/4表示到第
i
i
i个,前1/2/3/4段的最长长度
d
p
i
,
1
=
d
p
i
−
1
,
1
+
(
x
=
=
1
)
dp_{i,1}=dp_{i-1,1}+(x==1)
dpi,1=dpi−1,1+(x==1)
d
p
i
,
2
=
m
a
x
(
d
p
i
−
1
,
1
,
d
p
i
−
1
,
2
+
(
x
=
=
2
)
)
dp_{i,2}=max(dp_{i-1,1},dp_{i-1,2}+(x==2))
dpi,2=max(dpi−1,1,dpi−1,2+(x==2))
d
p
i
,
3
=
m
a
x
(
d
p
i
−
1
,
2
,
d
p
i
−
1
,
3
+
(
x
=
=
1
)
)
dp_{i,3}=max(dp_{i-1,2},dp_{i-1,3}+(x==1))
dpi,3=max(dpi−1,2,dpi−1,3+(x==1))
d
p
i
,
4
=
m
a
x
(
d
p
i
−
1
,
3
,
d
p
i
−
1
,
4
+
(
x
=
=
2
)
)
dp_{i,4}=max(dp_{i-1,3},dp_{i-1,4}+(x==2))
dpi,4=max(dpi−1,3,dpi−1,4+(x==2))
然后可以把第一维弄掉
这样空间时间全部
O
(
n
)
O(n)
O(n)
Code:
#include <bits/stdc++.h>
using namespace std;
int n, dp[5];
inline int read(){
int s = 0, w = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
return s * w;
}
int main(){
n = read();
int x;
for (int i = 1; i <= n; ++i){
x = read();
dp[1] += (x == 1);
dp[2] = max(dp[1], dp[2] + (x == 2));
dp[3] = max(dp[2], dp[3] + (x == 1));
dp[4] = max(dp[3], dp[4] + (x == 2));
}
printf("%d\n", dp[4]);
return 0;
}