浅谈换根 dp

今日一言 : Now or Never.

前言

换根 dp,用于解决根不固定的问题。所谓换根 dp,就是先求出一个根的答案,然后让子为根,将贡献由父转移到子,这样就能快速计算 n 个点分别为根时的答案

换根 dp 分为三步 :

  1. 先把树拎(ling)起来,找一点为根
  2. 计算每个子树的贡献
  3. 从根开始,由父节点向子节点传递贡献

洛谷 P3478

题意

找一个点,使其到所有节点的距离之和最大

解析

若选定一根,我们能 O(n) 求出距离之和,但每个节点都要算一次,那岂不是要 O(n2)

换根 dp 的思想可以轻松解决这个问题

先固定一个点为根,令 :

  • siz[u] 为子树的大小
  • ans[u] 为所有点到 u 的距离之和

我们可以先计算出根节点的答案,然后换根,若 u → v 进行转移,则转移方程为:

每次转移的时间复杂度 O(1),从根开始,向子转移,O(n) 即可计算出所有点为根时的 ans,最后求 max 即可

MYCODE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#include <bits/stdc++.h>
using namespace std;

void solve(){
int n;
cin >> n;
vector<vector<int>> gra(n + 1);
for(int i=1; i<n; i++){
int u, v;
cin >> u >> v;
gra[u].push_back(v);
gra[v].push_back(u);
}

vector<ll> dp(n + 1), siz(n + 1), dep(n + 1);
auto dfs1 = [&](auto&& self, int fa, int u) ->void {
siz[u] = 1;
for(int v : gra[u]) if(v != fa){
dep[v] = dep[u] + 1;
self(self, u, v);
siz[u] += siz[v];
}
};
dfs1(dfs1, 0, 1);
for(int i=1; i<=n; i++) dp[1] += dep[i];

auto dfs2 = [&](auto&& self, int fa, int u) ->void {
for(int v : gra[u]) if(v != fa){
dp[v] = dp[u] + (n-siz[v]) - siz[v];
self(self, u, v);
}
};
dfs2(dfs2, 0, 1);

ll mx = -1, idx = 0;
for(int i=1; i<=n; i++){
if(dp[i] > mx){
mx = dp[i];
idx = i;
}
}
cout << idx << '\n';
}

int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);

int t = 1; //cin >> t;
while(t--) solve();

return 0;
}

G - 2022 Guilin CCPC

题意

给定一棵树,点权为 ai。要选择两条链,可以交叉重叠,但公共部分对答案贡献为 0。求最大的 ans

解析

手玩可以发现,仅有两种情况,要么有 1 个交点,要么没有交点。

第一种情况

对于点 u,其儿子 v,找最大的 4 条链即可,链以 v 为端点,且在 v 的子树内

先固定一根,令 :

  • dp1[u] u 的子树内,以 u 为根对应的最大 4 链之和
  • p1[u][i](i = 0, 1, 2, 3) u 的儿子中,dp1i 大的那个儿子 v

先进行树上 dp 求出 dp1p1。观察某一个非根节点 u,可以发现 dp1[u] 的情况中少考虑了一条来自父亲的最长链,我们可以从上往下走时维护这条链,然后从这 5 条链里找最大的 4 条即可

第二种情况

可以将树拆成两个,分别维护最长链

同样固定一根,令 :

  • dp[u] u 的子树中的最长链
  • p2[u][i](i = 0, 1) u 的儿子中,dp2i 的那个儿子 v

对于某一个非根节点 v,我们有了 dp2[v],还需要在 v 的子树外再找一条最长链。可能是一条链来自父亲,然后和别的链拼接。或者是链不经过点 u,在 u 的其他儿子的子树内。几种情况取 max 即可,具体实现参考代码

MYCODE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <bits/stdc++.h>
using namespace std;
using ll = long long;

void solve(){
int n;
cin >> n;
vector<int> a(n + 1);
for(int i=1; i<=n; i++){
cin >> a[i];
}
vector<vector<int>> gra(n + 1);
for(int i=1; i<n; i++){
int u, v;
cin >> u >> v;
gra[u].push_back(v);
gra[v].push_back(u);
}

vector<ll> dp1(n + 1), dp2 = dp1;
vector<vector<int>> p1(n + 1, vector<int>(4)), p2(n + 1, vector<int>(2));
// 求出 dp1, dp2, p1, p2
auto dfs1 = [&](auto&& self, int fa, int u) ->void {
dp1[u] = dp2[u] = a[u];
for(int v : gra[u]) if(v != fa){
self(self, u, v);
dp1[u] = max(dp1[u], dp1[v] + a[u]);
dp2[u] = max(dp2[u], dp2[v]);
int x = v;
for(int i=0; i<4; i++){
if(dp1[x] > dp1[p1[u][i]]){
swap(x, p1[u][i]);
}
}

x = v;
for(int i=0; i<2; i++){
if(dp2[x] > dp2[p2[u][i]]){
swap(x, p2[u][i]);
}
}
}
dp2[u] = max(dp2[u], dp1[p1[u][0]] + dp1[p1[u][1]] + a[u]);
};
dfs1(dfs1, 0, 1);

ll ans = dp1[p1[1][0]] + dp1[p1[1][1]] + dp1[p1[1][2]] + dp1[p1[1][3]];
// 换根 dp // d1 为 u 为端点,父方向最长链 // d2 为 u 的父方向的最长链
auto dfs2 = [&](auto&& self, int fa, int u, ll d1, ll d2) ->void {
for(int v : gra[u]) if(v != fa){
ll nd1 = 0, nd2 = 0;
// 第一种情况
if(v == p1[u][0]){
nd1 = max(d1, dp1[p1[u][1]]) + a[u];
}
else{
nd1 = max(d1, dp1[p1[u][0]]) + a[u];
}

ans = max(ans, dp1[p1[v][0]] + dp1[p1[v][1]] + dp1[p1[v][2]] + max(nd1, dp1[p1[v][3]]));

// 第二种情况
if(v == p1[u][0]){
nd2 = max(d1 + dp1[p1[u][1]], dp1[p1[u][1]] + dp1[p1[u][2]]) + a[u];
}
else if(v == p1[u][1]){
nd2 = max(d1 + dp1[p1[u][0]], dp1[p1[u][0]] + dp1[p1[u][2]]) + a[u];
}
else{
nd2 = max(d1 + dp1[p1[u][0]], dp1[p1[u][0]] + dp1[p1[u][1]]) + a[u];
}
if(v == p2[u][0]) nd2 = max(nd2, dp2[p2[u][1]]);
else nd2 = max(nd2, dp2[p2[u][0]]);
nd2 = max(nd2, d2);

ans = max(ans, dp2[v] + nd2);

self(self, u, v, nd1, nd2);
}
};
dfs2(dfs2, 0, 1, 0, 0);

cout << ans << '\n';
}

int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);

int t = 1; //cin >> t;
while(t--) solve();

return 0;
}