AcWing 356. 次小生成树(LCA求次小生成树)

Description

给定一张 n 个点 m 条边的无向图,求无向图的严格次小生成树。

设最小生成树的边权之和为 sum,严格次小生成树就是指边权之和大于 sum 的生成树中最小的一个。

Input

第一行包含两个整数 nm ( n ≤ 105 ,  m ≤ 3×105 )

接下来 m 行,每行包含三个整数 x, y, z,表示点 x 和点 y 之前存在一条边,边的权值为 z

Output

一行一个数,表示严格次小生成树的边权和。(数据保证必定存在严格次小生成树)

Sample Input

5 6
1 2 1
1 3 2
2 4 3
3 5 4
3 4 3
4 5 6

Sample Output

11

LCA求次小生成树思路与步骤:
1.先用kruskal算法求出题目的最小生成树,并标记树边与非树边

2.建图:建立所有被标记过的,属于最小生成树的边

3.枚举添加每一条非树边并对应删去一条树边后得到的新生成树,从中更新出严格次小生成树的值。

4.在第3步中,假设添加的一条非树边连接的两个端点为a,b,记lca为a,b在最小生成树上的最近公共祖先。则对应删去的一条树边,应该为最小生成树上a到lca与b到lca两条路径上的最大边,删去最大边的目的是保证得到的新生成树尽可能小。
若两条路径中的最大边与a,b的边权相等,则删去次大边,以保证新生成树为严格次小生成树;
若两条路径上所有边都与a,b边权相等,即不存在次大边,则说明添加当前的非树边a,b并不能得到严格次小生成树,因此跳过该非树边a,b,继续枚举其它非树边。因为题目保证有解,所以一定可以在枚举其它非树边时得到严格次小生成树。

5.在第4步中,每当枚举添加一条连接a,b两个端点的非树边时,我们需要对应删除一条树边,即a到lac与b到lac两条路径上的最大边。联系倍增算法求lca,类似fa[j][k],我们定义两个数组:d1[j][k]与d2[j][k],分别表示当前节点j向上移动2^k层的路径上的最大边与次大边。
当k=0时:d1[j][k] = 结点j与其父节点的连边的权重,d2[j][k] = -0x3f3f3f3f;(因为向上移动一层只有一条边,不存在次大边,所以置d2[j][k]为负无穷)
当k>0时:因为2^k=2^(k-1) + 2^(k-1),即a向上移动2^k层,等价于a向上移动2^(k-1)层后再向上移动2^(k-1)层。因此记anc = fa[j][k-1]; distance[4] = {d1[j][k-1],d2[j][k-1],d1[anc][k-1],d2[anc][k-1]};则d1[j][k]与d2[j][k]分别为distance[4]中的最大值与次大值


全部代码如下:

#define _CRT_SECURE_NO_WARNINGS 1
#include<bits/stdc++.h>
using namespace std;

#define N 100020
#define M 300020
#define INF 0x3f3f3f3f

int n, m;
int per[N];
int fa[N][20], d1[N][20], d2[N][20];
int h[N], e[2 * M], w[2 * M], ne[2 * M], idx;
int depth[N];
struct Edge {
	int a, b, c;
	bool use = false;
	bool operator < (const Edge &W)const {
		return c < W.c;
	}
}edge[M];

void add(int a, int b, int c) {
	e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
	e[idx] = a, w[idx] = c, ne[idx] = h[b], h[b] = idx++;
}

int find(int x) {
	if (x != per[x])per[x] = find(per[x]);
	return per[x];
}

long long kruskal() {
	sort(edge, edge + m);
	long long ans = 0;
	int cnt = 0;
	for (int i = 0; i < m; i++) {
		int a = edge[i].a, b = edge[i].b, c = edge[i].c;
		int fa = find(a), fb = find(b);
		if (fa != fb) {
			edge[i].use = true;
			per[fa] = fb;
			ans += c;
			if (++cnt == n - 1)return ans;
		}
	}
}

void build() {
	memset(h, -1, sizeof h);
	for (int i = 0; i < m; i++)
		if (edge[i].use)
			add(edge[i].a, edge[i].b,edge[i].c);
}

void bfs() {
	memset(depth, 0x3f, sizeof depth);
	depth[0] = 0, depth[1] = 1;
	int q[N], hh = 0, tt = -1;
	q[++tt] = 1;
	while (hh <= tt) {
		int t = q[hh++];
		for (int i = h[t]; i != -1; i = ne[i]) {
			int j = e[i];
			if (depth[j] > depth[t]) {
			    q[++tt] = j;
				depth[j] = depth[t] + 1;
				fa[j][0] = t;
				d1[j][0] = w[i], d2[j][0] = -INF;
				for (int k = 1; k <= 16; k++) {
					int ans = fa[j][k - 1];
					fa[j][k] = fa[ans][k - 1];
					int distance[4] = { d1[j][k - 1],d2[j][k - 1],d1[ans][k - 1],d2[ans][k - 1] };
					d1[j][k] = -INF, d2[j][k] = -INF;
					for (int i = 0; i < 4; i++) {
						if (distance[i] > d1[j][k])
							d2[j][k] = d1[j][k], d1[j][k] = distance[i];
						else if (distance[i] < d1[j][k] && distance[i]>d2[j][k])
							d2[j][k] = distance[i];
					}
				}
			}	
		}	
	}
}

long long lac(int a, int b, int c) {
    
	if (depth[a] < depth[b])swap(a, b);
	int distance[N], cnt = 0;
	for (int k = 16; k >= 0; k--)
		if (depth[fa[a][k]] >= depth[b]) {
			distance[cnt++] = d1[a][k];
			distance[cnt++] = d2[a][k];
			a = fa[a][k];
		}	
	
	if(a!=b){	
    	for (int k = 16; k >= 0; k--) {
    		if (fa[a][k] != fa[b][k]) {
    			distance[cnt++] = d1[a][k];
    			distance[cnt++] = d2[a][k];
    			distance[cnt++] = d1[b][k];
    			distance[cnt++] = d2[b][k];			
    			a = fa[a][k];
    			b = fa[b][k];
    		}
    	}
    	distance[cnt++] = d1[a][0];
    	distance[cnt++] = d1[b][0];
	}
	
	int dist1 = -INF, dist2 = -INF;
	for (int i = 0; i < cnt; i++) {
		if (distance[i] > dist1)
			dist2 = dist1, dist1 = distance[i];
		else if (distance[i] < dist1 && distance[i] > dist2)
			dist2 = distance[i];
	}

	if (dist1 != c)return c - dist1;
	return c - dist2;
}

int a, b, c;
int main() {

	scanf("%d%d", &n, &m);
	for (int i = 0; i < m; i++) {
		scanf("%d%d%d", &a, &b, &c);
		edge[i] = { a,b,c,false };
	}

	for (int i = 1; i <= n; i++)per[i] = i;
	long long sum = kruskal();
	build();
	bfs();
    
    long long ans = 1e18;
	for (int i = 0; i < m; i++)
		if (!edge[i].use)
			ans = min(ans, sum + lac(edge[i].a, edge[i].b, edge[i].c));

	printf("%lld\n", ans);

	return 0;
}