DSU ON TREE

一种在 O(nlogn) 复杂度内求树上问题的算法。

引入

假设现在我们有一道题,如下:

给定一棵 n 个节点的树,每一个节点上都有一个颜色 ci,求每一个节点的子树中不同颜色的个数。
1n2×105
i[1,n],1cin

我们之前应该有过区间求不同颜色个数的经验,就是维护一个 cnt 数组,然后每一次以 O(1) 的复杂度修改 cnt 数组来维护不同颜色个数,这使得我们可以用莫队来维护。

到了树上该怎么办?

树上莫队!

但是树上莫队太麻烦了,复杂度还带根号,我们选用 ——DSU ON TREE!

什么是 DSU ON TREE

DSU on tree,全称树上启发式合并。

这个 DSU,拆开应该是 Disjoint Sets under Union,也就是并查集。

并查集我们按秩合并的时候,会把小的子树合并到大的子树的根节点上作为其子树之一,时间复杂度是熟悉的 O(nlogn)

在做树上启发式合并的时候,我们也采用类似的思想,把小的子树上的信息向大的子树上的信息合并,从而也达到 O(nlogn) 的时间复杂度。

证明的话放到操作里面。

大致思想

我们就拿上面的题目为例。

同样的套路,我们对每一个子树维护一个 cnt 数组,并尝试通过在 DFS 的时候把当前节点的所有儿子的信息合并到当前节点上来减少空间复杂度。

我们遍历一个节点,统计上其信息的复杂度是 O(1) 的,那统计一整棵子树的时间复杂度是 O(sz(i)) 的。
假如说我们每一次遍历到一个节点的时候,我们遍历其每一个儿子,计算其每一个儿子的答案,期间将 cnt 数组全部清空,最后在把自己子树内所有节点遍历一遍求自己子树内的答案,可以将空间复杂度减少到 O(n)
这样的话,每一个节点都会被统计 O(dep(i)) 次,时间复杂度是 O(i=1ndep(i)) 的,可以被精心构造的数据卡掉。

还记得树剖吗?每一个点到根的路径上,切换轻重边的次数(或者直接就等价于轻边的数量)是 O(logn) 级别的。

那我们考虑每一次遇到轻边的时候再重新被统计,此时我们总共统计的次数是 O(nlogn) 级别的了。
这与我们预估的时间复杂度相符,应该就是这个算法了。

具体操作

我们首先对这个东西进行一次剖分。
正常的树链剖分可以得到 fadeptop 什么的,但我们这里只需要与轻重边相关的信息,同时也只需要 DFS 一次。

之后我们进行答案的统计,这里也是根据 DFS 来实现的。

每一次我们遍历到一个节点的时候,我们首先遍历其所有轻子树,并单独计算其答案。此时我们每一次换子树遍历的时候需要清空 cnt 数组。
然后我们遍历我们的重子树,计算其答案,并保留其对 cnt 数组的贡献。
然后我们遍历所有的轻子树,保留其对 cnt 数组的贡献。
最后,计算当前节点子树的答案。

我们可以看到,每一个轻子树都遍历了两边,每一个重子树都遍历了一遍。

修改对 cnt 数组的贡献的时候,如果觉得递归的复杂度太大,可以选择在剖分的时候记录一下 DFS 序,因为子树的 DFS 序一定是连续的一段,我们只需要遍历 DFS 序中一段连续的区间即可。

代码实现

就以刚才我们说的那道题为例,这里只放上去了两个 DFS 函数和维护答案的函数。

代码使用的是上面说的遍历 DFS 序上的一个区间的方式,所以需要解释一下几个数组的含义:
id 是当前点的 DFS 序,nw 是当前 DFS 序代表的点。
剩下的或与树剖中的意义一样,或已经解释过了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
int fa[N], son[N], dep[N], sz[N];
int id[N], nw[N], dfn;
void dfs1(int p, int father)
{
fa[p] = father, sz[p] = 1;
id[p] = ++dfn, nw[dfn] = p;
for(int i = h[p]; ~i; i = ne[i])
{
if(e[i] == father)continue;
dfs1(e[i], p);
sz[p] += sz[e[i]];
if(sz[e[i]] > sz[son[p]])son[p] = e[i];
}
}
int col[N];
int cnt[N];
int totcol;
int ans[N];
void add(int i)
{
cnt[i]++;
if(cnt[i] == 1)totcol++;
}
void del(int i)
{
if(cnt[i] == 1)totcol--;
cnt[i]--;
}
void dfs2(int p, bool keep)
{
for(int i = h[p]; ~i; i = ne[i])
{
if(e[i] == fa[p] || e[i] == son[p])continue;
dfs2(e[i], false);
}
if(son[p])dfs2(son[p], true);
for(int i = h[p]; ~i; i = ne[i])
{
if(e[i] == fa[p] || e[i] == son[p])continue;
for(int j = 0; j < sz[e[i]]; j++)
add(col[nw[id[e[i]] + j]]);
}
add(col[p]);
ans[p] = totcol;
if(!keep)
{
for(int j = 0; j < sz[p]; j++)
del(col[nw[id[p] + j]]);
}
}

例题