P3574 [POI2014] FAR-FarmCraft 题解


true
FarmCraft
POI 2014
省选/NOI-
#9d3dcf
  • Luogu P3574
  • BZOJ #3829

题目需要我们从一号节点开始遍历整棵树,并给每一个节点记录下遍历到其的时间。
同时每一个节点又有一个自己的倒计时,当遍历到其的时候就开始计时。
我们需要求出所有节点中倒计时结束最慢的那一个的结束时间,并最小化之。

我们不难想象出一个简单的DP式子来求得最终的答案:

对于节点 $u$ 的子树 $v$,我们有如下的DP式子:

$\operatorname{dp}[u]=\max(\operatorname{dp}[u],\operatorname{dp}[v]+\operatorname{sz}[u]+1)$

其中 $\operatorname{dp}[u]$ 代表当前子树中最大的答案;
$\operatorname{sz}[u]$ 代表遍历过 $v$ 所在的子树之前已经经过的所有子树的大小之和再乘以2,这个在遍历完毕整个子树之后会更新为该子树的大小乘以2。

我们可以发现,我们最终的答案跟遍历子树的顺序有关,于是我们考虑对其进行排序。

对于一个节点 $x$ 的两个子树 $y$ 和 $z$,我们假设先遍历 $y$ 再遍历 $z$。
这样的话,我们的答案就是 $\max(\operatorname{dp}[y]+\operatorname{sz}[u]+1,\operatorname{dp}[z]+\operatorname{sz}[u]+\operatorname{sz}[y]+2+1)$。
我们假定这个方案比交换两个子树的遍历顺序得到的答案更优。
交换两个子树的遍历顺序之后得到的答案就是 $\max(\operatorname{dp}[z]+\operatorname{sz}[u]+1,\operatorname{dp}[y]+\operatorname{sz}[u]+\operatorname{sz}[z]+2+1)$。

我们最终得到如下式子:

$\max(\operatorname{dp}[y]+\operatorname{sz}[u]+1,\operatorname{dp}[z]+\operatorname{sz}[u]+\operatorname{sz}[y]+2+1) > \max(\operatorname{dp}[z]+\operatorname{sz}[u]+1,\operatorname{dp}[y]+\operatorname{sz}[u]+\operatorname{sz}[z]+2+1)$

我们将不等式左右两边同时约掉 $\operatorname{sz}[u]+1$,得到

$\max(\operatorname{dp}[y],\operatorname{dp}[z]+\operatorname{sz}[y]+2) > \max(\operatorname{dp}[z],\operatorname{dp}[y]+\operatorname{sz}[z]+2)$

因为 $\operatorname{dp}[y] < \operatorname{dp}[y]+\operatorname{sz}[z]+2$,$\operatorname{dp}[z] < \operatorname{dp}[z]+\operatorname{sz}[y]+2$,所以一定是 $\operatorname{dp}[y]+\operatorname{sz}[z]+2$ 与 $\operatorname{dp}[z]+\operatorname{sz}[y]+2$ 之间的差值导致了答案的变化。

因此,我们可以得到

$\operatorname{dp}[y]-\operatorname{sz}[y] < \operatorname{dp}[z]-\operatorname{sz}[z]$

然后我们就可以按照这样的方法排序了。

参考代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 500010;
template<typename T>
inline T read()
{
T x = 0, f = 1; char c = getchar();
while(!isdigit(c)) { if(c == '-')f = -1; c = getchar(); }
while(isdigit(c))x = x * 10 + (c ^ 48), c = getchar();
return x * f;
}
int n;
vector<int>e[N];
int val[N];
int f[N], sz[N];
bool cmp(const int a, const int b)
{
return sz[a] - f[a] < sz[b] - f[b];
}
void dfs(int p, int fa)
{
if(p != 1)f[p] = val[p];
if(e[p].empty())return;
for(auto i : e[p])
if(i != fa)dfs(i, p);
sort(e[p].begin(), e[p].end(), cmp);
for(auto i : e[p])
{
if(i == fa)continue;
f[p] = max(f[p], f[i] + sz[p] + 1);
sz[p] += sz[i] + 2;
}
}
int main()
{
n = read<int>();
for(int i = 1; i <= n; i++)
val[i] = read<int>();
for(int i = 1; i < n; i++)
{
int u = read<int>(), v = read<int>();
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1, 0);
printf("%d\n", max(f[1], sz[1] + val[1]));
return 0;
}