題目大意:給出一棵有根樹,n組詢問,每一組詢問給出樹上的一些關鍵點,問割掉一些邊使得根與這些點不聯通的最小花費是多少。總詢問的點不超過O(n)。
思路:基礎思路是每一次詢問做一次O(n)的DP,這本來已經夠快了,但是有很多詢問,這樣做就n^2了。注意到所有詢問的點加起來不超過O(n),也就是說每次詢問的點可能很少。那麼我們為何要將所有點掃一次?只需要將詢問的點重新建樹,然後跑樹形DP,這樣DP的總時間就是O(n)了。當然瓶頸在求兩點之間的最短邊上,O(nlogn)的倍增。
具體做法是維護一個單調棧,所有時刻這個棧中的所有點是從根開始的深度遞增的一條鏈。把所有點按照DFS序排序,依次加入棧中,同時維護這個棧,使它是一條鏈。假如新加進來的點與棧頂的LCA高於棧頂,那麼就說明新加進來的點不能繼續與棧頂形成鏈了。就將棧頂和次棧頂連邊,然後彈出棧頂。還有一些小細節什麼的。。
CODE:
#include#include #include #include #define MAX 510010 #define INF 0x3f3f3f3f using namespace std; struct Complex{ int x,pos; Complex(int _,int __):x(_),pos(__) {} Complex() {} bool operator <(const Complex &a)const { return pos < a.pos; } }src[MAX]; int points,asks; int head[MAX],total; int next[MAX],aim[MAX],length[MAX]; int pos[MAX],cnt; inline void Add(int x,int y,int len) { next[++total] = head[x]; aim[total] = y; length[total] = len; head[x] = total; } int father[MAX][20],_min[MAX][20]; int deep[MAX]; void DFS(int x,int last) { deep[x] = deep[last] + 1; pos[x] = ++cnt; for(int i = head[x]; i; i = next[i]) { if(aim[i] == last) continue; father[aim[i]][0] = x; _min[aim[i]][0] = length[i]; DFS(aim[i],x); } } void MakeTable() { for(int j = 1; j < 19; ++j) for(int i = 1; i <= points; ++i) { father[i][j] = father[father[i][j - 1]][j - 1]; _min[i][j] = min(_min[i][j - 1],_min[father[i][j - 1]][j - 1]); } } inline int GetLCA(int x,int y) { if(deep[x] < deep[y]) swap(x,y); for(int i = 19; ~i; --i) if(deep[father[x][i]] >= deep[y]) x = father[x][i]; if(x == y) return x; for(int i = 19; ~i; --i) if(father[x][i] != father[y][i]) x = father[x][i],y = father[y][i]; return father[x][0]; } inline int GetMin(int x,int y) { if(deep[x] < deep[y]) swap(x,y); int re = INF; for(int i = 19; ~i; --i) if(deep[father[x][i]] >= deep[y]) { re = min(re,_min[x][i]); x = father[x][i]; } for(int i = 19; ~i; --i) if(father[x][i] != father[y][i]) { re = min(re,_min[x][i]); re = min(re,_min[y][i]); x = father[x][i]; y = father[y][i]; } if(x != y) re = min(re,min(_min[x][0],_min[y][0])); return re; } struct Graph{ int head[MAX],v[MAX],T,total; int next[MAX],aim[MAX]; int super[MAX]; long long f[MAX]; void Add(int x,int y) { //cout << x << ' ' << y << endl; if(v[x] != T) { v[x] = T; head[x] = 0; } next[++total] = head[x]; aim[total] = y; head[x] = total; } void Set(int x) { super[x] = T; } void TreeDP(int x) { f[x] = 0; if(v[x] != T) { v[x] = T; head[x] = 0; } for(int i = head[x]; i; i = next[i]) { TreeDP(aim[i]); f[x] += min(super[aim[i]] == T ? INF:f[aim[i]],(long long)GetMin(x,aim[i])); } } }graph; int main() { cin >> points; for(int x,y,z,i = 1; i < points; ++i) { scanf("%d%d%d",&x,&y,&z); Add(x,y,z),Add(y,x,z); } DFS(1,0); MakeTable(); cin >> asks; for(int cnt,i = 1; i <= asks; ++i) { scanf("%d",&cnt); for(int j = 1; j <= cnt; ++j) scanf("%d",&src[j].x),src[j].pos = pos[src[j].x]; sort(src + 1,src + cnt + 1); ++graph.T; graph.total = 0; static int stack[MAX]; int top = 0; stack[++top] = 1; for(int j = 1; j <= cnt; ++j) { int lca = GetLCA(stack[top],src[j].x); while(deep[lca] < deep[stack[top]]) { if(deep[stack[top - 1]] <= deep[lca]) { int away = stack[top--]; if(stack[top] != lca) stack[++top] = lca; graph.Add(lca,away); break; } graph.Add(stack[top - 1],stack[top]),--top; } if(stack[top] != src[j].x) stack[++top] = src[j].x; graph.Set(src[j].x); } while(top) graph.Add(stack[top - 1],stack[top]),--top; graph.TreeDP(1); printf("%lld\n",graph.f[1]); } return 0; }