采蘑菇

Time Limit: 20 Sec Memory Limit: 256 MB

Description

img

Input

img

Output

img

Sample Input

5
  1 2 3 2 3
  1 2
  1 3
  2 4
  2 5

Sample Output

10
  9
  12
  9
  11

HINT

img

Main idea

询问从以每个点为起始点时,各条路径上的颜色种类的和。

Solution

我们看到题目,立马想到了O(n^2)的做法,然后从这个做法研究一下本质,我们确定了可以以点分治作为框架。

我们先用点分治来确定一个center(重心)。然后计算跟这个center有关的路径。设现在要统计的是经过center,对x提供贡献的路径。

我们先记录一个记录Sum[x]表示1~i-1子树中 颜色x 第一次出现的位置的那个点 的子树和,然后我们就利用这个Sum来解题。

我们显然可以分两种情况来讨论:

(1)统计center->x出现颜色的贡献
    显然,这时候,对于center->x这一段,直接像O(n^2)做法那样记录一个color表示到目前为止出现的颜色个数,然后加一下即可。再记录一个record表示当前可有的贡献和,一旦出现过一个颜色,那么这个颜色在1~i-1子树上出现第一次以下的点,对于x就不再提供贡献了,record减去Sum[这个颜色],然后这样深搜往下计算即可。

(2)统计center->x没出现过的颜色的贡献
    显然,对于center->x上没出现过的颜色,直接往下深搜,一开始为record为**(All - Sum[center])**,一旦出现了一个颜色,record则减去这个Sum。同样表示不再提供贡献即可。

我们这样做就可以求出每个子树前缀对于其的贡献了,倒着再做一边即可求出全部的贡献。统计x的时候,顺便统计一下center。可以满足效率,成功AC这道题。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#include<bits/stdc++.h>
using namespace std;

const int ONE = 600005;
const int INF = 214783640;
const int MOD = 1e9+7;

int n,x,y;
int Val[ONE];
int next[ONE],first[ONE],go[ONE],tot;
int vis[ONE];
int Ans[ONE],Sum[ONE];
int All;


int get()
{
int res,Q=1; char c;
while( (c=getchar())<48 || c>57)
if(c=='-')Q=-1;
if(Q) res=c-48;
while((c=getchar())>=48 && c<=57)
res=res*10+c-48;
return res*Q;
}

void Add(int u,int v)
{
next[++tot]=first[u]; first[u]=tot; go[tot]=v;
next[++tot]=first[v]; first[v]=tot; go[tot]=u;
}

namespace Point
{
int center;
int Stack[ONE],top;
int total,Max,center_vis[ONE];
int num,V[ONE];

struct power
{
int size,maxx;
}S[ONE];

void Getsize(int u,int father)
{
S[u].size=1;
S[u].maxx=0;
for(int e=first[u];e;e=next[e])
{
int v=go[e];
if(v==father || center_vis[v]) continue;
Getsize(v,u);
S[u].size += S[v].size;
S[u].maxx = max(S[u].maxx,S[v].size);
}
}

void Getcenter(int u,int father,int total)
{
S[u].maxx = max(S[u].maxx,total-S[u].size);
if(S[u].maxx < Max)
{
Max = S[u].maxx;
center = u;
}

for(int e=first[u];e;e=next[e])
{
int v=go[e];
if(v==father || center_vis[v]) continue;
Getcenter(v,u,total);
}
}

void Ad_sum(int u,int father)
{
if(!vis[Val[u]])
{
Stack[++top] = Val[u];
All += S[u].size; Sum[Val[u]] += S[u].size;
}
vis[Val[u]]++;
for(int e=first[u];e;e=next[e])
{
int v=go[e];
if(v==father || center_vis[v]) continue;
Ad_sum(v,u);
}
vis[Val[u]]--;
}

void Calc_in(int u,int father,int center,int Size,int f_time,int record)
{
if(!vis[Val[u]]) f_time++, record += Size, record -= Sum[Val[u]];
Ans[u] += record; Ans[center]+=f_time;
Ans[u] += f_time; vis[Val[u]] ++;
for(int e=first[u];e;e=next[e])
{
int v=go[e];
if(v==father || center_vis[v]) continue;
Calc_in(v,u,center,Size,f_time,record);
}
vis[Val[u]] --;
}

void Calc_not(int u,int father,int record)
{
if(!vis[Val[u]]) record -= Sum[ Val[u] ];
Ans[u] += record; vis[Val[u]] ++;
for(int e=first[u];e;e=next[e])
{
int v=go[e];
if(v==father || center_vis[v]) continue;
Calc_not(v,u,record);
}
vis[Val[u]] --;
}

void Dfs(int u)
{
Max = n;
Getsize(u,0);
Getcenter(u,0,S[u].size);
Getsize(center,0);
center_vis[center] = 1;

int num=0; for(int e=first[center];e;e=next[e]) if(!center_vis[go[e]]) V[++num]=go[e];

for(int i=1;i<=num;i++)
{
int v=V[i];
int Size = S[center].size - S[v].size - 1;
vis[Val[center]] = 1;
Calc_in(v,center,center, Size,1,All - Sum[Val[center]] + Size);
vis[Val[center]] = 0;
Ad_sum(v,center);
}
while(top) Sum[Stack[top--]]=0; All=0;

for(int i=num;i>=1;i--)
{
int v=V[i];
vis[Val[center]] = 1;
Calc_not(v,center, All-Sum[Val[center]]);
vis[Val[center]] = 0;
Ad_sum(v,center);
}

while(top) Sum[Stack[top--]]=0; All=0;
for(int e=first[center];e;e=next[e])
{
int v=go[e];
if(center_vis[v]) continue;
Dfs(v);
}
}

}

int main()
{
n=get();
for(int i=1;i<=n;i++) Val[i]=get();

for(int i=1;i< n;i++)
{
x=get(); y=get();
Add(x,y);
}

Point:: Dfs(1);
for(int i=1;i<=n;i++)
printf("%d\n",Ans[i]+1);
}