mex 权值和化为连通块计数

牛客寒假集训营 4 D-数字积木 题解

题意

给定 n 个节点的树,点权构成一个 [0, n) 的排列,求所有连通块的 mex 权值和

思路

将 mex 权值和 转化为 连通块计数。具体来说,分别统计包括 {0},{0, 1} … {0,…,n-1} 的连通块个数,求这些个数之和就等价于求所有连通块的 mex 权值和。

举个例子,若我们有一个 mex=4 的连通块 S,包括 {0, 1, 2, 3, (mex=4), …} 这些数字。当我们统计包括 {0} 的连通块的个数时,S 会贡献 1,然后是统计 {0, 1},{0, 1, 2},{0, 1, 2, 3},S 也都会贡献 1,一共贡献了 4,正好和其 mex 权值相同

dp 进行连通块计数

令 dp[i] 为 i 必选,i 的子树的连通块个数,每个子节点可选可不选,转移方程为

更新集合贡献

若上一次的必选集合是 {0, 1, 2 … x},贡献是 lst

现在我们要将 x+1 加入集合,我们从 x+1 开始,不断的向上直到遇到已经必选的点,更新贡献

这样就将经过的点都变成了必选点,之后更新 ans,ans += lst

时间复杂度 O(nlog(mod))

注意:要特别处理 (dp[u]+1) % mo = 0 的情况

MYCODE

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using pii = pair<int, int>;

constexpr int inf = 0x3f3f3f3f;
constexpr int mo = 1e9 + 7;

ll qpow(ll a, ll x){
ll ans = 1;
while(x){
if(x & 1) ans *= a;
a *= a;
x >>= 1;
ans %= mo;
a %= mo;
}
return ans;
}

ll inv(int x){
return qpow(x, mo-2);
}

// 支持除 0,处理 inv(0) 的情况
struct ZInt {
ll prod = 1;
int cnt = 0;
void operator *= (ll x){
x %= mo;
if(x == 0) cnt++;
else prod*=x, prod%=mo;
}
void operator /= (ll x){
x %= mo;
if(x == 0) cnt--;
else prod*=inv(x), prod%=mo;
}
ll get(){ return (cnt? 0:prod); }
};

void solve(){
int n;
cin >> n;
vector<int> a(n + 1), id(n);
for(int i=1; i<=n; i++){
cin >> a[i];
id[a[i]] = i;
}

ll ans = 0;
vector<vector<int>> gra(n + 1);
for(int i=0; i<n-1; i++){
int u, v;
cin >> u >> v;
gra[u].push_back(v);
gra[v].push_back(u);
}

vector<ll> dp(n + 1), p(n + 1);
auto dfs = [&](auto&& self, int fa, int u) -> void {
dp[u] = 1;
p[u] = fa;
for(int v : gra[u]) if(v != fa){
self(self, u, v);
dp[u] *= dp[v] + 1;
dp[u] %= mo;
}
};
dfs(dfs, 0, id[0]);

ZInt lst;
lst *= dp[id[0]] + 1;
vector<bool> vis(n + 1);
for(int i=0; i<n; i++){
int x = id[i];
while(!vis[a[x]]){
vis[a[x]] = true;
lst /= dp[x]+1;
lst *= dp[x];
// lst *= inv(dp[x]+1) * dp[x] % mo; // dp[x]+1 = 1e9+7 !
// lst %= mo;
x = p[x];
}
ans += lst.get();
ans %= mo;
}
cout << ans << '\n';
}

int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);

int t = 1; //cin >> t;
while(t--) solve();

return 0;
}