0%

问题描述

我们知道,求任意图的最大独立集是一类 NP 完全问题,目前还没有准确的多项式算法,但是有许多多项式复杂度的近似算法。

例如,小 C 常用的一种算法是:

  1. 对于一个 $n$ 个点的无向图,先等概率随机一个 $1$ 到 $n$ 的排列 $p[1\ldots n]$。
  2. 维护答案集合 $S$ ,一开始 $S$ 为空集,之后按照 $i=1\ldots n$ 的顺序,检查 ${p[i]}\cup S$ 是否是一个独立集,如果是的话就令 $S={p[i]}\cup S$。
  3. 最后得到一个独立集 $S$ 作为答案。

小 C 现在想知道,对于给定的一张图,这个算法的正确率,输出答案对 $998244353$ 取模

输入

第一行两个非负整数 $n,m$ 表示给定的图的点数和边数。

接下来 $m$ 行,每行有两个正整数 $(u,v) (u\neq v)$ 描述这张图的一条无向边。

输出

输出正确率,答案对 $998244353$ 取模。

思路

这题比较巧妙的地方就是设计状态。考虑比较朴素的一种方法, $0$ 表示一个点未被选,$1$ 表示一个点选了但没在集合中,$2$ 表示一个点在集合中,这样的状态数就是 $n^3$,我们需要考虑的就是如何优化这种状态。

这里要跳出按照排列的顺序选点这个思维定式。只要选了点 $i$ 加入独立集中,那么与 $i$ 相邻的点在它之后任意时候选都不会产生影响,所以假设有 $s1$ 个点没有考虑,这其中有 $s2$ 个点是与 $i$ 相邻的,那么我们考虑这 $s2$ 个点在排列中的情况数,不难发现这个就是 $A_{s1}^{s2}$。注意我们这里说的是「考虑」而不是选取,也就是说先预先给 $s2$ 个点排好位置,而不是按照排列的顺序依次选了这些点。

所以我们可以这样设计状态,设 $f_{S, i}$ 为考虑了集合 $S$ 且其中独立集大小为 $i$ 的方案数。按照上文提到的考虑方法,可以知道只要是不属于集合 $S$ 的点一定可以加入 $S$ 中的独立集,因为与独立集相邻的点已经在 $S$ 中了。设已经考虑了点集 $S$ ,此时要加入点 $k$,与$k$ 相邻的点集为 $e_k$,还有 $s1$ 个点没有考虑,这其中有 $s2$ 个点是与 $k$ 相邻的,有转移方程

其中 $k\notin S$,$s1=n-|S|-1$,$s2=|e_k\bigcap(e_k\bigcup S)|$。这样转移总复杂度就是 $O(n^22^n)$ 了。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 25
#define MAXM (1<<20)+5
#define INF 0x3f3f3f3f
#define p 998244353
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int n, m, e[MAXN], f[MAXM][MAXN], bit[MAXM], fac[MAXN], inv[MAXN];

void init()
{
fac[0]=fac[1]=inv[0]=inv[1]=1;
for(rint i=0; i<(1<<n); ++i) bit[i]=bit[i>>1]+(i&1);
for(rint i=2; i<=n; ++i) inv[i]=1LL*inv[p%i]*(p-p/i)%p;
for(rint i=2; i<=n; ++i)
{
inv[i]=1LL*inv[i]*inv[i-1]%p;
fac[i]=1LL*fac[i-1]*i%p;
}
}

int A(int x, int y)
{
return 1LL*fac[x]*inv[x-y]%p;
}

int main()
{
scanf("%d%d", &n, &m);
init();
for(rint i=1, x, y; i<=m; ++i)
{
scanf("%d%d", &x, &y);
x--, y--;
e[x]|=(1<<y);
e[y]|=(1<<x);
}
for(rint i=0; i<n; ++i) e[i]|=(1<<i);
f[0][0]=1;
for(rint sta=0; sta<(1<<n); ++sta)
for(rint i=0; i<n; ++i)
{
if(!f[sta][i]) continue;
for(rint j=0; j<n; ++j)
if(!((sta>>j)&1))
{
int nxt=sta|e[j], s1=n-bit[sta]-1, s2=bit[e[j]^(e[j]&sta)]-1;
f[nxt][i+1]=(f[nxt][i+1]+1LL*f[sta][i]*A(s1, s2)%p)%p;
}
}
for(rint i=n; i>=0; i--)
if(f[(1<<n)-1][i])
{
printf("%lld\n", 1LL*f[(1<<n)-1][i]*inv[n]%p);
return 0;
}
return 0;
}

问题描述

九条可怜在玩一个很好玩的策略游戏:Slay the Spire,一开始九条可怜的卡组里有 $2n$ 张牌,每张牌上都写着一个数字 $w_i$,一共有两种类型的牌,每种类型各 $n$ 张:

攻击牌:打出后对对方造成等于牌上的数字的伤害。

强化牌:打出后,假设该强化牌上的数字为 $x$,则其他剩下的攻击牌的数字都会乘上 $x$。保证强化牌上的数字都大于 $1$。

现在九条可怜会等概率随机从卡组中抽出 $m$ 张牌,由于费用限制,九条可怜最多打出 $k$ 张牌,假设九条可怜永远都会采取能造成最多伤害的策略,求她期望造成多少伤害。

假设答案为 $ans$,你只需要输出 $ans * \frac{(2n)!}{m!(2n-m)!} \mod 998244353 $

输入

第一行一个正整数 $T$ 表示数据组数 接下来对于每组数据: 第一行三个正整数 $n,m,k$ 第二行 $n$ 个正整数 $w_i$,表示每张强化牌上的数值 第三行 $n$ 个正整数 $w_i$,表示每张攻击牌上的数值

输出

输出 $T$ 行,每行一个非负整数表示每组数据的答案。

思路

首先题目要求的其实是所有 ${2n \choose m}$ 种情况造成的伤害之和,是一个计数问题。

有一个很巧妙的条件是「强化牌上的数字是大于 1 的整数 」我们可以据此得出最优策略是:如果强化牌数量大于 $i$,那么就先用最大的 $k-1$ 张强化牌,再用最大的 $1$ 张攻击牌;否则就是先用掉所有强化牌,再用最大的 $k-i$ 张攻击牌。

先将两种牌从大到小排序,设计状态 $f_{i,j}$ 为用了$i$ 张牌最小的牌为 $w_j$ 所有这种情况的强化倍数之和;$g_{i,j}$ 为用了$i$ 张牌最小的牌为 $w[j]$ 所有这种情况的总伤害之和。它们的转移方程分别为 但是这样还不够,我们还要考虑每一个状态所对应的情况数。设 $F_{i,j}$ 为抽到了 $i$ 张强化牌,并使用了其中 $j$ 张的情况的总强化倍数乘对应情况数;$G_{i,j}$ 为抽到了 $i$ 张攻击牌,并使用了其中 $j$ 张的情况的总伤害乘对应情况数。$F_{i,j},G_{i,j}$ 满足式子 这时候我们最初提到的最优策略就要用到了,我们先枚举 $m$ 张牌的类型,然后根据最优策略算出对应的贡献。因此答案就是

最后是实现的细节,我们先预处理 $f_{i,j}$ 和 $g_{i,j}$,利用前缀和可以做到 $O(n^2)$。然后我们可以在计算答案时对于每一个 $F_{i,j},G_{i,j}$ 用 $O(n)$ 计算出来,这样总复杂度就是 $O(n^2)$ 的。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 3005
#define INF 0x3f3f3f3f
#define p 998244353
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int T, n, m, k, ans, w[MAXN], v[MAXN], f[MAXN][MAXN], g[MAXN][MAXN];
int inv[MAXN], fac[MAXN];

bool CMP(int x, int y)
{
return x>y;
}

void init()
{
inv[0]=inv[1]=fac[0]=fac[1]=1;
for(rint i=2; i<=3000; ++i) inv[i]=1LL*inv[p%i]*(p-p/i)%p;
for(rint i=2; i<=3000; ++i)
{
inv[i]=1LL*inv[i-1]*inv[i]%p;
fac[i]=1LL*fac[i-1]*i%p;
}
}

int comb(LL x, LL y)
{
if(y>x) return 0;
return 1LL*fac[x]*inv[y]%p*inv[x-y]%p;
}

int calf(int x, int y)
{
int sum=0;
if(y==0) return comb(n, x);
for(rint i=1; i<=n; ++i)
sum=(sum+1LL*f[y][i]*comb(n-i, x-y))%p;
return sum;
}

int calg(int x, int y)
{
int sum=0;
for(rint i=1; i<=n; ++i)
sum=(sum+1LL*g[y][i]*comb(n-i, x-y))%p;
return sum;
}

int main()
{
init();
scanf("%d", &T);
while(T--)
{
ans=0;
scanf("%d%d%d", &n, &m, &k);
for(rint i=1; i<=n; ++i) scanf("%d", &w[i]);
for(rint i=1; i<=n; ++i) scanf("%d", &v[i]);

sort(w+1, w+n+1, CMP);
sort(v+1, v+n+1, CMP);
for(rint i=1; i<=n; ++i)
g[1][i]=v[i], f[1][i]=w[i];
for(rint i=2; i<=n; ++i)
{
int sum1=0, sum2=0;
for(rint j=1; j<=n; ++j)
{
f[i][j]=1LL*w[j]*sum1%p;
g[i][j]=(1LL*comb(j-1, i-1)*v[j]+sum2)%p;
//printf("%d %d %d %d %d!!!\n", i, j, sum2, w[j]);
sum1=(sum1+f[i-1][j])%p;
sum2=(sum2+g[i-1][j])%p;
}
}
//for(rint i=1; i<=n; ++i)
// for(rint j=1; j<=n; ++j) printf("%d %d %d %d!!!\n", i, j, f[i][j], g[i][j]);
for(rint i=0; i<m; ++i)
{
int a=calf(i, min(i, k-1));
int b=calg(m-i, max(1, k-i));
ans=(ans+1LL*a*b)%p;
//printf("%d %d %d!!\n", ans, a, b);
}
printf("%d\n", ans);
}
return 0;
}

1月6日 Educational DP Contest 题目链接:https://atcoder.jp/contests/dp

大部分水题就不讲了,记录几道有趣的题目。

J - Sushi

思路: 期望要倒着推!!!

设 $f_{i,j,k}$ 为有 $i$ 个 $3$ 个的寿司碟子,有 $j$ 个 $2$ 个的寿司碟子,有 $k$ 个 $1$ 个的寿司碟子的期望,那么就有 $n-i-j-k$ 个空着的碟子。然后倒着推期望!!!!

这样我们就得到了递推的关系。因为转移的顺序问题,我们不能简单地按照 $i, j, k$ 的顺序 DP。但发现转移是从寿司总数少的状态到寿司总数多到状态,因此我们要按照寿司的总数进行 DP。总时间复杂度 $O(n^3)$。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 305
#define INF 0x3f3f3f3f
#define p 1000000007
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int n, sum, a[MAXN], num[4];
LD f[MAXN][MAXN][MAXN];

int main()
{
scanf("%d", &n);
for(rint i=1; i<=n; ++i)
{
scanf("%d", &a[i]);
num[a[i]]++; sum+=a[i];
}
for(rint x=1; x<=sum; ++x)
{
for(rint i=0; i<=x/3; ++i)
{
for(rint j=0; j<=(x-i*3)/2; ++j)
{
int k=x-i*3-j*2;
if(i+j+k>n) continue;
//printf("%d %d %d!!!\n", i, j, k);
f[i][j][k]=n;
if(i>0) f[i][j][k]+=i*f[i-1][j+1][k];
if(j>0) f[i][j][k]+=j*f[i][j-1][k+1];
if(k>0) f[i][j][k]+=k*f[i][j][k-1];
f[i][j][k]/=(i+j+k);
//printf("%d %d %d %Lf!!!\n", i, j, k, f[i][j][k]);
}
}
}
printf("%.10Lf\n", f[num[3]][num[2]][num[1]]);
return 0;
}

O - Matching

思路: 一个简单的容斥计数,感觉和 DP 关系不大。

枚举女生集合 $S$,计算所有男生和集合中女生配对的方案数,记为 $f[S]$。既然题目要求一一配对的方案,那么如果 $n-|S|$ 为奇数,那么答案就应该减去 $f[S]$,否则就加上 $f[S]$。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 21
#define INF 0x3f3f3f3f
#define p 1000000007
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int n, ans, s[MAXN];

int main()
{
scanf("%d", &n);
for(rint i=0; i<n; ++i)
for(rint j=0, x; j<n; ++j)
{
scanf("%d", &x);
s[i]|=(x<<j);
}
for(rint sta=0; sta<(1<<n); ++sta)
{
int siz=0, val=1;
for(rint i=0; i<n; ++i)
siz+=((sta>>i)&1);
for(rint i=0; i<n; ++i)
{
int temp=s[i]&sta, num=0;
for(rint j=0; j<n; ++j)
num+=((temp>>j)&1);
val=1LL*val*num%p;
}
if((n-siz)&1) ans=(ans+p-val)%p;
else ans=(ans+val)%p;
}
printf("%d\n", ans);
return 0;
}

U - Grouping

思路: 一个简单的子集 DP。

枚举所有集合 $S$,记 $f[S]$ 为选取 $S$ 集合的最大答案。先算出把 $S$ 集合作为一整组的分数作为 $f[S]$ 的初值,然后枚举 $S$ 的子集 $T$,得到转移方程 $f[S]=max(f[S], f[T]+f[S xor T])$,直接转移就行。总复杂度 $O(3^nn^2)$ 代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 20
#define INF 0x3f3f3f3f
#define p 1000000007
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int n, mp[MAXN][MAXN];
LL f[1 << 17];

int main()
{
scanf("%d", &n);
for(rint i = 0; i < n; ++i)
for(rint j = 0; j < n; ++j)
scanf("%d", &mp[i][j]);

for(rint sta = 0; sta < (1 << n); ++sta)
{
for(rint i = 0; i < n; ++i)
if((sta >> i) & 1)
for(rint j = i + 1; j < n; ++j)
if((sta >> j)&1) f[sta] += mp[i][j];
for(rint i = sta; i; i = (i-1) & sta)
{
int j = sta ^ i;
f[sta] = max(f[sta], f[i] + f[j]);
//printf("%d %d %d %lld %lld!!\n", sta, i, j, f[i], f[j]);
}
}
printf("%lld\n", f[(1 << n) - 1]);
}

V - Subtree

思路: 一个简单的树形 DP 加换根操作,但是有一个很有趣的细节。

设 $f_i$ 为以 $i$ 为根的子树的方案数,那么就有 DP 方程:$f_i=\prod (f_{to}+1)$。考虑换根操作,设 $val_i$ 为节点 $i$ 的父亲所在子树的贡献,那么节点 $i$ 的答案就是 $(val_i+1)*f[i]$。$val$ 的求法一般有两种,第一种是通过式子:$val_{to}=(val_x+1)*f_x/f_{to}$ 得到。这种方法看似很简单,但是其中有除法,对于模意义下的除法表面上看可以用扩展欧几里得求逆元实现。但是,有一些数是没有逆元的,也就是说对于任意除数的除法我们不能够简单地实现。

这时候我们就要考虑第二种求 $val$ 的方法,那就是给子树一个固定的顺序,然后求出子树的前缀和后缀贡献,从而推出 $val_{to}$,虽然这种方法更为复杂,但可以很好地避免除法。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 100005
#define INF 0x3f3f3f3f
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int n, m, cnt, head[MAXN], f[MAXN], ans[MAXN];
vector<int> vec[MAXN], l[MAXN], r[MAXN];

void dfs1(int x, int fa)
{
f[x] = 1;
for(rint i = 0; i < vec[x].size(); ++i)
{
int to = vec[x][i];
if(to == fa) continue;
dfs1(to, x);
f[x]= 1LL * f[x] * (f[to]+1) % m;
}
}

void dfs2(int x, int fa, int val)
{
ans[x] = 1LL * f[x] * (val + 1) % m;
l[x].resize(vec[x].size());
r[x].resize(vec[x].size());
for(rint i = 0; i < vec[x].size(); ++i)
{
if(i == 0) l[x][i] = 1;
else
{
if(vec[x][i - 1] == fa) l[x][i] = l[x][i - 1];
else l[x][i] = 1LL * l[x][i - 1] * (f[vec[x][i - 1]] + 1) % m;
}
}
for(rint i=vec[x].size()-1; i >= 0; --i)
{
if(i == vec[x].size() - 1) r[x][i] = 1;
else
{
if(vec[x][i + 1] == fa) r[x][i] = r[x][i + 1];
else r[x][i] = 1LL * r[x][i + 1] * (f[vec[x][i + 1]] + 1) % m;
}
}
//printf("%d %d %d!!\n", x, fa, val);
for(rint i = 0; i < vec[x].size(); ++i)
{
int to = vec[x][i];
if(to == fa) continue;
//printf("%d %d %d!!!\n", to, l[x][i], r[x][i]);
dfs2(to, x, 1LL * (val + 1) * l[x][i] % m * r[x][i] % m);
}
}

int main()
{
scanf("%d%d", &n, &m);
for(rint i = 1, x, y; i < n; ++i)
{
scanf("%d%d", &x, &y);
vec[x].push_back(y);
vec[y].push_back(x);
}
dfs1(1, 0);
dfs2(1, 0, 0);
for(rint i = 1; i <= n; ++i) printf("%d\n", ans[i]);
return 0;
}

W - Intervals

思路:

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 200005
#define INF 0x3f3f3f3f
#define rint register int
#define LL long long
#define LD long double
#define pii pair<int, int>
#define ls (root<<1)
#define rs (root<<1|1)
#define mid ((l+r)>>1)
using namespace std;

int n, m;
LL t[MAXN*4], tag[MAXN*4];
vector<pii> vec[MAXN];

void up(int root)
{
t[root]=max(t[ls], t[rs])+tag[root];
}

void update(int root, int l, int r, int x, int y, LL k)
{
//printf("%d %d %d %d!!!\n", root, l, r, k);
if(l>y || r<x) return;
if(l>=x && r<=y)
{
t[root]+=k, tag[root]+=k;
return ;
}
update(ls, l, mid, x, y, k);
update(rs, mid+1, r, x, y, k);
up(root);
}

int main()
{
scanf("%d%d", &n, &m);
for(rint i=1; i<=m; ++i)
{
int l, r, a;
scanf("%d%d%d", &l, &r, &a);
vec[r].push_back(make_pair(l, a));
}
for(rint i=1; i<=n; ++i)
{
update(1, 1, n, i, i, t[1]);
for(rint j=0; j<vec[i].size(); ++j)
{
pii temp=vec[i][j];
update(1, 1, n, temp.first, i, temp.second);
}
}
printf("%lld\n", max(0LL, t[1]));
}

问题描述

设 $T$ 为一棵有根树,我们做如下的定义:

• 设 $a$ 和 $b$ 为 $T$ 中的两个不同节点。如果 $a$ 是 $b$ 的祖先,那么称「$a$ 比 $b$ 不知道高明到哪里去了」。

• 设 $a$ 和 $b$ 为 $T$ 中的两个不同节点。如果 $a$ 与 $b$ 在树上的距离不超过某个给定常数 $x$,那么称「$a$ 与 $b$ 谈笑风生」。

给定一棵 $n$ 个节点的有根树 $T$,节点的编号为 $1∼n$,根节点为 $1$ 号节点。你需要回答 $q$ 个询问,询问给定两个整数 $p $ 和 $k$,问有多少个有序三元组 $(a, b, c)$ 满足:

• $a$,$b$ 和 $c$ 为 $T$ 中三个不同的点,且 $a$ 为 $p$ 号节点;

• $a$ 和 $b$ 都比 $c$ 不知道高明到哪里去了;

• $a$ 和 $b$ 谈笑风生。这里谈笑风生中的常数为给定的 $k$。

输入

输入文件的第一行含有两个正整数 $n$ 和 $q$,分别代表有根树的点数与询问的个数。

接下来 $n−1$ 行,每行描述一条树上的边。每行含有两个整数 $u$ 和 $v$,代表在节点 $u$ 和 $v$ 之间有一条边。

接下来 $q$ 行,每行描述一个操作。第 $i$ 行含有两个整数,分别表示第 $i$ 个询问的 $p$ 和 $k$。

输出

输出 $q$ 行,每行对应一个询问,代表询问的答案。

思路

特别经典的一道题,大部分树上的数据结构都能够在这题使用。除了下面讲到都主席树和线段树合并的做法之外,还可以用树上启发式合并,长链剖分等做法去做。

我们将答案分成两个部分。第一部分是 $b$ 是 $a$ 的祖先,答案就是:

第二部分是 $a$ 是 $b$ 的祖先,答案就是:

于是我们不难想到以 $dep[x]$ 为关键字,$siz[x]-1$ 为权值建立权值线段树。对于每一个询问的答案就是在对应线段树上询问区间 $[dep[x]+1, dep[x]+k]$。实现时就可以用主席树或线段树合并来节省空间。

代码

主席树:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 300005
#define INF 0x3f3f3f3f
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int n, q, cnt, id, tot, head[MAXN], dep[MAXN], siz[MAXN], in[MAXN], out[MAXN], ver[MAXN], root[MAXN];

struct Edge {int next, to;} edge[MAXN*2];
struct Node {int ls, rs; LL val;} t[MAXN*40];

void addedge(int from, int to)
{
edge[++cnt].next=head[from];
edge[cnt].to=to;
head[from]=cnt;
}

void dfs(int x, int fa)
{
siz[x]=1; in[x]=++id; ver[id]=x;
for(rint i=head[x]; i; i=edge[i].next)
{
int to=edge[i].to;
if(to==fa) continue;
dep[to]=dep[x]+1;
dfs(to, x);
siz[x]+=siz[to];
}
out[x]=id;
}

void update(int &rt1, int rt2, int l, int r, int x, int k)
{
rt1=++tot; t[rt1]=t[rt2]; t[rt1].val+=k;
if(l==r) return;
int mid=(l+r)>>1;
if(x<=mid) update(t[rt1].ls, t[rt2].ls, l, mid, x, k);
else update(t[rt1].rs, t[rt2].rs, mid+1, r, x, k);
}

LL query(int rt, int l, int r, int x, int y)
{
if(l>y || r<x) return 0;
if(l>=x && r<=y) return t[rt].val;
int mid=(l+r)>>1;
return query(t[rt].ls, l, mid, x, y)+query(t[rt].rs, mid+1, r, x, y);
}

int main()
{
scanf("%d%d", &n, &q);
for(rint i=1, x, y; i<n; ++i)
{
scanf("%d%d", &x, &y);
addedge(x, y);
addedge(y, x);
}
dfs(1, 0);
for(rint i=1; i<=n; ++i)
update(root[i], root[i-1], 1, n, dep[ver[i]], siz[ver[i]]-1);
while(q--)
{
int x, k;
scanf("%d%d", &x, &k);
LL a=1LL*min(dep[x], k)*(siz[x]-1);
LL b=query(root[in[x]], 1, n, dep[x], dep[x]+k);
LL c=query(root[out[x]], 1, n, dep[x], dep[x]+k);
printf("%lld\n", a+c-b);
}
return 0;
}

线段树合并:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#define MAXN 300005
#define INF 0x3f3f3f3f
#define rint register int
#define LL long long
#define LD long double
using namespace std;

int n, q, cnt, tot, head[MAXN], dep[MAXN], siz[MAXN], root[MAXN];
LL ans[MAXN];

struct Edge {int next, to;} edge[MAXN*2];
struct Node {int ls, rs; LL val;} t[MAXN*40];

void addedge(int from, int to)
{
edge[++cnt].next=head[from];
edge[cnt].to=to;
head[from]=cnt;
}

void update(int &rt, int l, int r, int x, int k)
{
if(!rt) rt=++tot; t[rt].val+=k;
if(l==r) return;
int mid=(l+r)>>1;
if(x<=mid) update(t[rt].ls, l, mid, x, k);
else update(t[rt].rs, mid+1, r, x, k);
}

int merge(int x, int y, int l, int r)
{
if((!x) || (!y)) return x+y;
int mid=(l+r)>>1, rt=++tot;
t[rt].val=t[x].val+t[y].val;
t[rt].ls=merge(t[x].ls, t[y].ls, l, mid);
t[rt].rs=merge(t[x].rs, t[y].rs, mid+1, r);
return rt;
}

LL query(int rt, int l, int r, int x, int y)
{
if(l>y || r<x || !rt) return 0;
if(l>=x && r<=y) return t[rt].val;
int mid=(l+r)>>1;
return query(t[rt].ls, l, mid, x, y)+query(t[rt].rs, mid+1, r, x, y);
}

void dfs(int x, int fa)
{
siz[x]=1; dep[x]=dep[fa]+1;
for(rint i=head[x]; i; i=edge[i].next)
{
int to=edge[i].to;
if(to==fa) continue;
dfs(to, x);
siz[x]+=siz[to];
root[x]=merge(root[x], root[to], 1, n);
}
update(root[x], 1, n, dep[x], siz[x]-1);
}

int main()
{
scanf("%d%d", &n, &q);
for(rint i=1, x, y; i<n; ++i)
{
scanf("%d%d", &x, &y);
addedge(x, y);
addedge(y, x);
}
dfs(1, 0);
for(rint i=1; i<=q; ++i)
{
int x, k;
scanf("%d%d", &x, &k);
LL a=1LL*min(dep[x]-1, k)*(siz[x]-1);
LL b=query(root[x], 1, n, dep[x]+1, dep[x]+k);
printf("%lld\n", a+b);
}
return 0;
}