這個題有兩種做法
1.並查集
初始時一條邊都不加,將所有邊按權值從大到小排序,然後依次判斷每一個邊兩端的頂點是否是均為machine節點,如果是則應刪除這條邊,否則加入這條邊,然後在並查集合並時盡量讓根節點為machine節點。
[cpp]
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<queue>
#include<math.h>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
int cnt;
int n,k;
ll sum;
bool is[100100];
int fa[100100];
struct Edge{
int from,to,cost;
}edge[100100];
void addedge(int s,int t,int c){
edge[cnt].from=s;
edge[cnt].to=t;
edge[cnt++].cost=c;
}
void init(){
cnt=0;
memset(is,0,sizeof(is));
for(int i=0;i<n;i++)
fa[i]=i;
sum=0;
}
bool cmp(struct Edge a,struct Edge b){
return a.cost<b.cost;
}
int query(int i){
if(fa[i]!=i)
fa[i]=query(fa[i]);
return fa[i];
}
void merge(int a,int b){
int aa=query(a);
int bb=query(b);
if(is[bb])
fa[aa]=bb;
else
fa[bb]=aa;
}
int main(){
int t,T,i,j,u,v,w,uu,vv;
int tem,p,q;
bool flag;
scanf("%d",&T);
for(t=1;t<=T;t++){
scanf("%d %d",&n,&k);
init();
for(i=1;i<n;i++){
scanf("%d %d %d",&u,&v,&w);
addedge(u,v,w);
}
for(i=1;i<=k;i++){
scanf("%d",&tem);
is[tem]=1;
}
sort(edge,edge+cnt,cmp);
for(i=cnt-1;i>=0;i--){
flag=0;
uu=query(edge[i].from);
vv=query(edge[i].to);
if(uu!=vv && is[uu] && is[vv])
flag=1;
if(flag){
sum+=edge[i].cost;
//printf("****%d %d %d\n",i,edge[i].from,edge[i].to);
}
else
merge(edge[i].from,edge[i].to);
}
cout<<sum<<endl;
}
}
2.樹形DP
這裡dis[i]數組保存的是從i節點向下走到machine節點的最小邊權,如果i節點本身為machine節點,那麼dis[i]=MAX.
當前的根節點如果是machine節點,那麼與根節點相連的所有能走到machine節點的節點的都要刪去。
當前的根節點如果是machine節點,那麼要留下一個刪除費用最大的節點用於向上更新。
[cpp]
#include<cstdio>
#include<cstring>
#include<vector>
#include<queue>
#include<iostream>
#include<algorithm>
using namespace std;
const long long MAX=(~(0ULL)>>1);
typedef long long ll;
int cnt,head[100100];
int n,k;
ll sum;
bool is[100100];
ll dis[100100];
vector<ll>vec[100100];
struct Edge{
int to,cost,next;
}edge[200100];
void addedge(int s,int t,int c){
edge[cnt].to=t;
edge[cnt].next=head[s];
edge[cnt].cost=c;
head[s]=cnt++;
edge[cnt].to=s;
edge[cnt].next=head[t];
edge[cnt].cost=c;
head[t]=cnt++;
}
void init(){
cnt=0;
memset(head,-1,sizeof(head));
memset(is,0,sizeof(is));
sum=0;
for(int i=0;i<n;i++)
vec[i].clear();
}
ll mini(ll a,ll b){
return a>b?b:a;
}
void dfs(int s,int father){
int i,j;
int all=0;
ll s1=0,s2=0;
for(i=head[s];i!=-1;i=edge[i].next){
j=edge[i].to;
if(j!=father){
dfs(j,s);
if(is[j])
vec[s].push_back(mini(dis[j],edge[i].cost));
}
}
if(vec[s].size()==0){
dis[s]=MAX;
return;
}
sort(vec[s].begin(),vec[s].end());
for(i=0;i<vec[s].size()-1;i++)
sum+=(ll)vec[s][i];
if(is[s]){
sum+=(ll)vec[s][i];
dis[s]=MAX;
}
else{
is[s]=1;
dis[s]=vec[s][i];
}
}
int main(){
int t,T,i,j,u,v,w,tem;
scanf("%d",&T);
for(t=1;t<=T;t++){
scanf("%d %d",&n,&k);
init();
for(i=1;i<n;i++){
scanf("%d %d %d",&u,&v,&w);
addedge(u,v,w);
}
for(i=1;i<=k;i++){
scanf("%d",&tem);
is[tem]=1;
}
memset(dis,0,sizeof(dis));
dfs(0,-1);
cout<<sum<<endl;
}
}