BOJ 15481 그래프와 MST

문제 : https://www.acmicpc.net/problem/15481

어떤 임의의 연결그래프가 주어졌을 때, 모든 간선에 대해서 해당 간선이 포함되는 MST의 가중치의 합을 구하라는 문제이다.

잘 생각해보면 간선을 MST에 이미 포함된 간선과 포함되지 않은 간선으로 나눌 수 있는데, 이 때 MST에 이미 포함된 간선에서의 MST 가중치 합은 너무 당연하게도 전체 MST의 가중치 합과 같다.

그럼 이 간선들로 MST를 만들었다치고 나머지 MST에 포함되지 않은 간선들의 경우는 어떨까?
해당 간선을 이루는 양 끝 정점을 (u,v)라 해보자. 이 u,v를 이으면서 트리구조를 유지하려면 어떤 간선을 끊어줘야하는데 이 말은 곧 u~v 경로 내의 한 간선을 지워야 한다는 뜻이고, 트리의 특성상 두 정점을 잇는 경로는 ulca(u,v)v로 유일하므로 이 경로 내의 한 간선을 지우되, Minimum한 Tree를 만들기위해 가중치가 가장 큰 간선을 지우라는 뜻이다.

따라서 해당 간선에서의 답은 (MST의 가중치 합) - (u~v경로 내의 최대크기 간선) + (u-v 간선가중치) 가 된다.

잘 이해가 가지 않는다면 아래 그림을 보고 임의의 두 정점을 골라서 연결을 끊어보자.
무조건 ulca(u,v)v 내의 한 간선을 골라야 함을 알 수 있다.
syntax_tree

전체 코드는 아래와 같다.

#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
using lli = long long;
struct Edge
{
	int u, v, w, idx;
	bool operator<(Edge& e) { return w < e.w; }
};

int N, M, pa[200'010], seg[4*200'010], num[200'010], hld[200'010], w[200'010], p[200'010];
int d[200'010], node_num, cc, flag[200'010];
lli mst, ans[200'010];
Edge edge[200'010];
vector<int> adj[200'010];

int f(int u);
void m(int u, int v, int w, int idx);
void dfs(int u);
void dfs2(int u);
void su(int i, int v, int n = 1, int s = 1, int e = N);
int sq(int l, int r, int n = 1, int s = 1, int e = N);
int q(int u, int v);

int main()
{
	ios_base::sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	memset(pa, -1, sizeof pa);
	cin >> N >> M;
	for (int i = 0; i < M; ++i)
	{
		int u, v, w, idx;
		cin >> u >> v >> w;
		edge[i] = { u,v,w,i };
	}
	sort(edge, edge + M);
	for (int i = 0; i < M; ++i)
		m(edge[i].u, edge[i].v, edge[i].w, edge[i].idx);
	dfs(1); dfs2(1);
	for (int i = 0; i < M; ++i)
	{
		auto& [u, v, w, idx] = edge[i];
		if (flag[idx]) su((d[u] > d[v] ? num[u] : num[v]), w);
	}
	for (int i = 0; i < M; ++i)
	{
		auto& [u, v, w, idx] = edge[i];
		if (flag[idx]) ans[idx] = mst;
		else
			ans[idx] = mst - q(u, v) + w;
	}
	for (int i = 0; i < M; ++i)
		cout << ans[i] << "\n";
}

int f(int u)
{
	return pa[u] < 0 ? u : pa[u] = f(pa[u]);
}

void m(int u, int v, int w, int idx)
{
	int u1 = u, v1 = v;
	u = f(u); v = f(v);
	if (u == v) return;
	adj[u1].emplace_back(v1);
	adj[v1].emplace_back(u1);
	flag[idx] = 1;
	pa[u] += pa[v];
	pa[v] = u;
	mst += w;
}

void dfs(int u)
{
	w[u] = 1;
	for (auto& v : adj[u])
	{
		if (w[v] == 0)
		{
			p[v] = u;
			d[v] = d[u] + 1;
			dfs(v);
			w[u] += w[v];
		}
	}
}

void dfs2(int u)
{
	int cc = -1;
	num[u] = ++node_num;
	for (auto& v : adj[u])
	{
		if (p[v] == u && (cc == -1 || w[v] > w[cc]))
			cc = v;
	}
	if (cc != -1) { hld[cc] = hld[u]; dfs2(cc); }
	for (auto& v : adj[u])
	{
		if (p[v] == u && v != cc)
		{
			hld[v] = v;
			dfs2(v);
		}
	}
}

void su(int i, int v, int n, int s, int e)
{
	if (s > i || e < i) return;
	if (s == e) { seg[n] = v; return; }
	int m = s + e >> 1;
	su(i, v, n * 2, s, m);
	su(i, v, n * 2 + 1, m + 1, e);
	seg[n] = max(seg[n * 2], seg[n * 2 + 1]);
}

int sq(int l, int r, int n, int s, int e)
{
	if (s > r || e < l) return 0;
	if (l <= s && e <= r) return seg[n];
	int m = s + e >> 1;
	return max(sq(l, r, n * 2, s, m), sq(l, r, n * 2 + 1, m + 1, e));
}

int q(int u, int v)
{
	int ret = 0;
	while (hld[u] != hld[v])
	{
		if (d[hld[u]] < d[hld[v]]) swap(u, v);
		ret = max(ret, sq(num[hld[u]], num[u]));
		u = p[hld[u]];
	}
	if (d[u] > d[v]) swap(u, v);
	return max(ret, sq(num[u] + 1, num[v]));
}

위 코드에서는 경로내의 최대크기 간선을 구하기 위해 hld를 사용하였지만 업데이트가 없는 쿼리이므로 sparse table을 구성하여도 상관이 없다.