BOJ 15481 그래프와 MST
문제 : https://www.acmicpc.net/problem/15481
어떤 임의의 연결그래프가 주어졌을 때, 모든 간선에 대해서 해당 간선이 포함되는 MST의 가중치의 합을 구하라는 문제이다.
잘 생각해보면 간선을 MST에 이미 포함된 간선과 포함되지 않은 간선으로 나눌 수 있는데, 이 때 MST에 이미 포함된 간선에서의 MST 가중치 합은 너무 당연하게도 전체 MST의 가중치 합과 같다.
그럼 이 간선들로 MST를 만들었다치고 나머지 MST에 포함되지 않은 간선들의 경우는 어떨까?
해당 간선을 이루는 양 끝 정점을 (u,v)라 해보자. 이 u,v를 이으면서 트리구조를 유지하려면 어떤 간선을 끊어줘야하는데 이 말은 곧 u~v 경로 내의 한 간선을 지워야 한다는 뜻이고, 트리의 특성상 두 정점을 잇는 경로는 ulca(u,v)v로 유일하므로 이 경로 내의 한 간선을 지우되, Minimum한 Tree를 만들기위해 가중치가 가장 큰 간선을 지우라는 뜻이다.
따라서 해당 간선에서의 답은 (MST의 가중치 합) - (u~v경로 내의 최대크기 간선) + (u-v 간선가중치) 가 된다.
잘 이해가 가지 않는다면 아래 그림을 보고 임의의 두 정점을 골라서 연결을 끊어보자.
무조건 ulca(u,v)v 내의 한 간선을 골라야 함을 알 수 있다.
전체 코드는 아래와 같다.
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
using lli = long long;
struct Edge
{
int u, v, w, idx;
bool operator<(Edge& e) { return w < e.w; }
};
int N, M, pa[200'010], seg[4*200'010], num[200'010], hld[200'010], w[200'010], p[200'010];
int d[200'010], node_num, cc, flag[200'010];
lli mst, ans[200'010];
Edge edge[200'010];
vector<int> adj[200'010];
int f(int u);
void m(int u, int v, int w, int idx);
void dfs(int u);
void dfs2(int u);
void su(int i, int v, int n = 1, int s = 1, int e = N);
int sq(int l, int r, int n = 1, int s = 1, int e = N);
int q(int u, int v);
int main()
{
ios_base::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
memset(pa, -1, sizeof pa);
cin >> N >> M;
for (int i = 0; i < M; ++i)
{
int u, v, w, idx;
cin >> u >> v >> w;
edge[i] = { u,v,w,i };
}
sort(edge, edge + M);
for (int i = 0; i < M; ++i)
m(edge[i].u, edge[i].v, edge[i].w, edge[i].idx);
dfs(1); dfs2(1);
for (int i = 0; i < M; ++i)
{
auto& [u, v, w, idx] = edge[i];
if (flag[idx]) su((d[u] > d[v] ? num[u] : num[v]), w);
}
for (int i = 0; i < M; ++i)
{
auto& [u, v, w, idx] = edge[i];
if (flag[idx]) ans[idx] = mst;
else
ans[idx] = mst - q(u, v) + w;
}
for (int i = 0; i < M; ++i)
cout << ans[i] << "\n";
}
int f(int u)
{
return pa[u] < 0 ? u : pa[u] = f(pa[u]);
}
void m(int u, int v, int w, int idx)
{
int u1 = u, v1 = v;
u = f(u); v = f(v);
if (u == v) return;
adj[u1].emplace_back(v1);
adj[v1].emplace_back(u1);
flag[idx] = 1;
pa[u] += pa[v];
pa[v] = u;
mst += w;
}
void dfs(int u)
{
w[u] = 1;
for (auto& v : adj[u])
{
if (w[v] == 0)
{
p[v] = u;
d[v] = d[u] + 1;
dfs(v);
w[u] += w[v];
}
}
}
void dfs2(int u)
{
int cc = -1;
num[u] = ++node_num;
for (auto& v : adj[u])
{
if (p[v] == u && (cc == -1 || w[v] > w[cc]))
cc = v;
}
if (cc != -1) { hld[cc] = hld[u]; dfs2(cc); }
for (auto& v : adj[u])
{
if (p[v] == u && v != cc)
{
hld[v] = v;
dfs2(v);
}
}
}
void su(int i, int v, int n, int s, int e)
{
if (s > i || e < i) return;
if (s == e) { seg[n] = v; return; }
int m = s + e >> 1;
su(i, v, n * 2, s, m);
su(i, v, n * 2 + 1, m + 1, e);
seg[n] = max(seg[n * 2], seg[n * 2 + 1]);
}
int sq(int l, int r, int n, int s, int e)
{
if (s > r || e < l) return 0;
if (l <= s && e <= r) return seg[n];
int m = s + e >> 1;
return max(sq(l, r, n * 2, s, m), sq(l, r, n * 2 + 1, m + 1, e));
}
int q(int u, int v)
{
int ret = 0;
while (hld[u] != hld[v])
{
if (d[hld[u]] < d[hld[v]]) swap(u, v);
ret = max(ret, sq(num[hld[u]], num[u]));
u = p[hld[u]];
}
if (d[u] > d[v]) swap(u, v);
return max(ret, sq(num[u] + 1, num[v]));
}
위 코드에서는 경로내의 최대크기 간선을 구하기 위해 hld를 사용하였지만 업데이트가 없는 쿼리이므로 sparse table을 구성하여도 상관이 없다.