P8290 [省选联考 2022] 填树

题解

Posted by WrongAnswer_90's Blog on April 16, 2024

My Blogs

P8290 [省选联考 2022] 填树

很有意思的拉插优化 DP。

首先可以枚举 $L$ 来限制选的数的值域在 $L,L+k$ 中。然后进行树上 DP:设 $v_i$ 表示当前点 $i$ 能填多少种数,$w_i$ 表示当前点 $i$ 能填的数的和。

$f_i$ 表示当前 $i$ 子树内的所有合法根链数量,$g_i$ 表示 $i$ 子树内所有根链的权值之和。假设在合并子树 $to$,转移方程:

\[\begin{aligned} v_i&=\max(0,\min(r_i,L+k)-\max(l_i,L)+1)\\ w_i&=\max(0,\frac{(\min(r_i,L+k)+\max(l_i,L))v_i}2)\\ ans1&\rightarrow ans1+f_if_{to}\\ ans2&\rightarrow ans2+f_ig_{to}+f_{to}g_i\\ f_i&\rightarrow f_i+v_if_{to}\\ g_i&\rightarrow g_i+v_ig_{to}+f_{to}w_i \end{aligned}\]

但是这样会算重,比如一个最小值为 $3$,最大值为 $5$ 的方案,在 $k=3$ 的时候会被算两次($L=2,R=5$ 时被算一次,$L=3,R=6$ 时被算一次)。可以在 DP 中多开一维记录最小值是否取到了 $L$,但是略显繁琐,常数也并不优秀。

所以考虑容斥,算 $2V$ 遍,每次用 $(L,L+k)$ 的答案减去 $(L+1,L+k)$ 的答案,这样得到的就是最小值恰好为 $L$ 的答案。

一次 DP 复杂度是 $\mathcal O(n)$,总复杂度 $\mathcal O(nV)$。但是还不够。

观察 $ans1$ 和 $ans2$,如果用分段法把 $\min$ 和 $\max$ 拆开,则他们在同一段下一定都是关于 $V$ 的不超过 $n+1$ 次的多项式。

对于 $ans1$,其是由若干个 $v$ 乘起来的,而 $v$ 要么是常数,要么是关于 $L$ 的单项式,所以次数不会超过 $n$。

对于 $ans2$,是由一个 $w$ 和多个 $v$ 乘起来的,而 $w$ 可能是一个关于 $L$ 的二次多项式,所以最高次数是 $n+1$。下面只考虑处理 $ans2$,$ans1$ 同理。

这样考虑设一个连续段 $l,r$,设 $ans_i$ 为 $L=i+l-1$ 时的 $ans2$。有结论:把 $ans_i$ 做一遍前缀和,得到的 $ans_i$ 是关于 $L$ 的不超过 $n+2$ 次的多项式。

考虑设 $F(x)=\sum_{i=0}^na_ix^i$,则做一遍前缀和之后 $S’(x)=\sum_{i=0}^na_i\sum_{j=1}^xx^j$,最高次项是一个 $n+1$ 次的自然数幂和,所以得到的前缀和最高次不超过 $n+2$。

所以可以只求 $n+3$ 个点值,用拉插公式求出在 $L=r$ 时前缀和的点值。连续段个数是 $\mathcal O(n)$ 的,每次需要做 $\mathcal O(n)$ 次 DP(来确定前 $n+3$ 个点值),每次 DP 复杂度是 $\mathcal O(n)$,拉插复杂度不是瓶颈,总复杂度 $\mathcal O(n^3)$。注意常数优化。

另一种做法:可以考虑枚举了连续段之后不暴力拉插,而是 DP 的时候直接多项式。因为次数不会超过 $siz_i$,所以暴力做复杂度是树上背包的 $\mathcal O(n^2)$。然后得到了一个 $\mathcal O(n)$ 次的多项式,想办法求出它的前缀和应该也能做,但是感觉不如拉插简洁和直接。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
int n,m,ans1,ans2,sum1,sum2,cnt,len,L,R,v[210],v2[210],fr[210];
int a[210],b[210],f[210],g[210],head[210],to[210],nex[210],numa[810];
vi T[210];
inline void add(int x,int y){to[++cnt]=y,nex[cnt]=head[x],head[x]=cnt;}
void dfs0(int x,int fa){for(auto to:T[x])if(to!=fa)add(x,to),dfs0(to,x);}
void dfs(int x)
{
	int v1=0,v2=0;
	if(min(R,b[x])<max(L,a[x]))f[x]=g[x]=0;
	else
	f[x]=v1=max(0,min(R,b[x])-max(L,a[x])+1),
	g[x]=v2=Cmul(min(R,b[x])+max(L,a[x]),f[x],(MOD+1)>>1);
	Madd(sum1,f[x]),Madd(sum2,g[x]);
	for(int i=head[x];i;i=nex[i])
	{
		dfs(to[i]);
		if(!v1)continue;
		Madd(sum1,Cmul(f[x],f[to[i]])),Madd(sum2,Cmul(f[x],g[to[i]]),Cmul(g[x],f[to[i]]));
		Madd(g[x],Cadd(Cmul(f[to[i]],v2),Cmul(v1,g[to[i]]))),Madd(f[x],Cmul(v1,f[to[i]]));
	}
}
inline void mian()
{
	read(n,m),fr[0]=1;int x,y,rr=-inf;L=inf;
	for(int i=1;i<=204;++i)fr[i]=Cmul(fr[i-1],i);
	for(int i=1;i<=n;++i)read(a[i],b[i]),Mmin(L,a[i]),Mmax(rr,b[i]);
	for(int i=1;i<n;++i)read(x,y),T[x].eb(y),T[y].eb(x);
	dfs0(1,0),R=L+m-1;
	for(int i=1;i<=n;++i)numa[++len]=a[i],numa[++len]=b[i]-m+1,numa[++len]=b[i]+1,numa[++len]=a[i]-m;
	sort(numa+1,numa+1+len),len=unique(numa+1,numa+1+len)-numa-1;
	for(int i=1;i<len;++i)
	{
		if(numa[i+1]-numa[i]<=n+3)
		{
			L=numa[i],R=L+m;
			while(L<numa[i+1])sum1=sum2=0,dfs(1),Madd(ans1,sum1),Madd(ans2,sum2),++L,++R;
			continue;
		}
		for(int j=1;j<=n+3;++j)
		{
			L=numa[i]-1+j,R=numa[i]-1+j+m,sum1=sum2=0,dfs(1);
			v[j]=Cadd(v[j-1],sum1),v2[j]=Cadd(v2[j-1],sum2);
		}
		int tmp=1,X=numa[i+1]-1;
		for(int j=1;j<=n+3;++j)Mmul(tmp,X-(numa[i]+j-1));
		for(int j=1;j<=n+3;++j)
		{
			int va=fr[j-1];
			if((n+3-j)&1)Mmul(va,Cdel(0,fr[n+3-j]));
			else Mmul(va,fr[n+3-j]);
			Madd(ans1,Cmul(v[j],tmp,power(Cmul(va,X-(numa[i]+j-1)),MOD-2)));
			Madd(ans2,Cmul(v2[j],tmp,power(Cmul(va,X-(numa[i]+j-1)),MOD-2)));
		}
	}
	--m;
	for(int i=1;i<len;++i)
	{
		if(numa[i+1]-numa[i]<=n+3)
		{
			L=numa[i],R=L+m;
			while(L<numa[i+1])sum1=sum2=0,dfs(1),Mdel(ans1,sum1),Mdel(ans2,sum2),++L,++R;
			continue;
		}
		for(int j=1;j<=n+3;++j)
		{
			L=numa[i]-1+j,R=numa[i]-1+j+m,sum1=sum2=0,dfs(1);
			v[j]=Cadd(v[j-1],sum1),v2[j]=Cadd(v2[j-1],sum2);
		}
		int tmp=1,X=numa[i+1]-1;
		for(int j=1;j<=n+3;++j)Mmul(tmp,X-(numa[i]+j-1));
		for(int j=1;j<=n+3;++j)
		{
			int va=fr[j-1];
			if((n+3-j)&1)Mmul(va,Cdel(0,fr[n+3-j]));
			else Mmul(va,fr[n+3-j]);
			Mdel(ans1,Cmul(v[j],tmp,power(Cmul(va,X-(numa[i]+j-1)),MOD-2)));
			Mdel(ans2,Cmul(v2[j],tmp,power(Cmul(va,X-(numa[i]+j-1)),MOD-2)));
		}
	}
	write(ans1,'\n',ans2);
}