P5901 [IOI2009] Regions

题解

Posted by WrongAnswer_90's Blog on October 26, 2023

P5901 [IOI2009] Regions

更好的阅读体验

根号分治,过掉不难,但是想 $\mathcal O(n\sqrt n)$ 还是有一些思维含量的。

经过思考,发现 polylog 十分困难,考虑根号的算法。

首先有一种暴力:预处理两两颜色间的答案,$\mathcal O(1)$ 查询。首先枚举颜色数,然后每种颜色 $\mathcal O(n)$ 扫一遍,这样预处理的复杂度是 $\mathcal O(n\times \mathrm{处理的颜色数})$。直接硬预处理显然过不去,所以可以考虑对出现次数较多的颜色预处理。设块长为 $B$,出现次数大于 $B$ 的颜色称为大块,否则为小块。

对于大块与大块,大块与小块,小块与大块之间的答案还是想上面一样预处理,预处理复杂度是 $\mathcal O(n\times \dfrac{n}{B})$。

对于小块,内部点数是 $\mathcal O(B)$ 级别的,树上的祖先后代问题用 dfn 序转化为区间问题:

$\sum_{i\in col_x}\sum_{j\in col_y}[in_i\leq dfn_j\leq out_i]$

直接做有 $in,out$ 两个限制,但是一个 $dfn$ 不可能即 $>out_i$ 有 $<in_i$,所以可以容斥一下,用总的点对数减去不合法的点对数,问题变为

$\sum_{i\in col_x}\sum_{j\in col_y}[dfn_j<in_i]$

可以树状数组做到 $B\log n$,但是更优秀的做法是预处理排序好的 $in$ 和 $dfn$,用类似归并排序求逆序对的思路双指针解决,对于 $>out_i$ 的部分也是同理。

总复杂度 $\mathcal O(n\times \dfrac{n}{B}+nB)$,$B$ 取 $\sqrt n$ 时复杂度为 $\mathcal O(n\sqrt n)$。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
	int n,m,q,cnt,B=360,tot,dfn[200001],nfd[200001],pos[200001],sum[200001],a[200001],head[200001],to[200001],nex[200001],siz[25001],ans1[160][25001],ans2[160][25001];
	inline void add(int x,int y){to[++cnt]=y,nex[cnt]=head[x],head[x]=cnt;}
	vector<int> larg,ve[25001],ve2[25001];
	void dfs(int k){dfn[k]=++tot;for(int i=head[k];i;i=nex[i])dfs(to[i]);nfd[k]=tot;}
	void dfs1(int k){for(int i=head[k];i;i=nex[i])dfs1(to[i]),sum[k]+=sum[to[i]];}
	void dfs2(int k){for(int i=head[k];i;i=nex[i])sum[to[i]]+=sum[k],dfs2(to[i]);}
	inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
	inline bool cmp2(int x,int y){return nfd[x]<nfd[y];}
	inline void mian()
	{
		cin>>n>>m>>q>>a[1],++siz[a[1]],ve[a[1]].eb(1);int x,y;
		for(int i=2;i<=n;++i)cin>>x>>a[i],++siz[a[i]],ve[a[i]].eb(i),add(x,i);
		for(int i=1;i<=m;++i)if(siz[i]>B)pos[i]=larg.size(),larg.eb(i);
		dfs(1);
		for(int i=1;i<=m;++i)ve[i].eb(0),ve2[i]=ve[i],sort(ve2[i].begin(),ve2[i].end(),cmp2),sort(ve[i].begin(),ve[i].end(),cmp);
		for(int i=0;i<larg.size();++i)
		{
			for(auto j:ve[larg[i]])++sum[j];
			dfs1(1);
			for(int j=1;j<=n;++j)ans1[i][a[j]]+=sum[j];
			memset(sum,0,sizeof(sum));
			for(auto j:ve[larg[i]])++sum[j];
			dfs2(1);
			for(int j=1;j<=n;++j)ans2[i][a[j]]+=sum[j];
			memset(sum,0,sizeof(sum));
		}
		while(q--)
		{
			cin>>x>>y;
			if(pos[x])cout<<ans2[pos[x]][y]<<endl;
			else if(pos[y])cout<<ans1[pos[y]][x]<<endl;
			else
			{
				int ans=siz[x]*siz[y];
				for(int i=1,j=1;i<=siz[x]&&j<=siz[y];++j)
				{
					while(i<=siz[x]&&dfn[ve[x][i]]<=dfn[ve[y][j]])++i;
					ans-=siz[x]-i+1;
				}
				for(int i=siz[x],j=siz[y];i>0&&j>0;--j)
				{
					while(i>0&&nfd[ve2[x][i]]>=dfn[ve[y][j]])--i;
					ans-=i;
				}
				cout<<ans<<endl;
			}
			fflush(stdout);
		}
	}