BOJ 16213 dgeu-learning

출처 : https://www.acmicpc.net/problem/16213

요즘 HLD 문제를 풀고 있는데, 구현량이 꽤 많은 문제를 발견해서 여기에 올린다. (C 한정)
참고로 난 구현량이 많으면 좋...다.... 힣...

오랜만에 정렬을 구현해 보고 싶다면 이 문제를 추천한다.

풀이 :
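가중치가 큰 간선부터 추가하는 크루스칼 알고리즘으로 MAXIMUM spanning tree(최대 신장 트리)를 만든다. 두 정점을 잇는 경로 중 가장 작은 간선 가중치를 최대로 만드는 경로는 항상 최대 신장 트리 위에 있기 때문이다.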

만들어진 트리를 가지고 heavy-light decomposition을 돌린다.

min segment tree를 만든다.

각 트리 간선의 가중치를 자식 정점의 HLD 번호(dnum) 위치에 저장하여 segment tree를 갱신한다.

각 쿼리는 HLD를 좀 풀어본 사람이라면 답을 구할 수 있을 것이다. 두 정점의 LCA를 구한 뒤, LCA에서 각 정점까지 올라가는 경로의 간선 가중치 최솟값을 segment tree로 구하고 그중 작은 값을 출력하면 된다.
HLD가 뭔지 모르겠다면 이 글을 보자

아래는 코드이다.

#include<malloc.h>
#include<stdio.h>
#include<stdlib.h>
/*
 * Neutral element of the min segment tree and the "no edge" answer.
 * Assumed to exceed every edge weight in the input (TODO confirm against the
 * problem's weight bound). An enum constant is typed and debugger-visible,
 * unlike an object-like macro.
 */
enum { inf = 1000000010 };
/* One undirected weighted edge of the input graph. */
struct edge
{
	int a; /* first endpoint */
	int b; /* second endpoint */
	int c; /* weight; msort orders edges by it */
};
/* vn vertices, en input edges, qn queries; lidx = number of spanning-tree edges stored in list. */
int vn, en, qn, lidx;
/*
 * origin: all input edges (msort orders them by descending weight in place);
 * sorted: scratch buffer used by merge;
 * list:   edges accepted into the maximum spanning tree (recorded by pbl).
 */
struct edge origin[500010], sorted[500010], list[200010];
/* Disjoint-set forest for Kruskal: parent links (up) and union-by-rank ranks (ur). */
int up[200010], ur[200010];
/* Spanning-tree adjacency lists built by pb: elist[v] holds eindex[v] neighbours. */
int *elist[200010], eindex[200010];
/* Filled by go1: subtree sizes and tree parents (0 for the root). */
int size[200010], parent[200010];
/*
 * Filled by go2 (HLD):
 *   cindex[h]      - next free offset inside the chain headed by h,
 *   chain_num[v]   - head vertex of v's chain,
 *   chain_depth[v] - number of light edges above v's chain,
 *   chain_index[v] - offset of v inside its chain,
 *   dnum[v]        - DFS position of v (1-based); num is the next position.
 */
int cindex[200010], chain_num[200010], chain_depth[200010], chain_index[200010], dnum[200010], num;
/*
 * Min segment tree in heap layout (root at index `node`, covering positions
 * [s, e]); leaves occupy indices [l, r], position p lives at leaf g + p.
 */
int node, s, e, l, r, g, *tree;

/* (void) makes these real prototypes so argument mismatches are diagnosed. */
void input(void);
void make_tree(void);
void init_set(void);
void msort(struct edge arr[], int lo, int hi);
void merge(struct edge arr[], int lo, int mid, int hi);
int cmp(struct edge x, struct edge y);
int find(int v);
void uni(int x, int y);
void lnk(int x, int y);
void pb(int from, int to);
void pbl(int u, int v, int w);
void make_hld(void);
int go1(int cur, int par);
void go2(int cur, int par, int cur_chain_num, int cur_chain_depth);
void make_segment_tree(void);
void seg_update(int idx, int val);
int get_min(int x, int y);
void proc(void);
int query(int u, int v);
int lca(int u, int v);
int hq(int anc, int v);
int seg_query(int at, int lo, int hi, int ql, int qr);
void fin(void);
void make_answer(void);
/* Program entry: run the complete read-solve-print pipeline, then report success. */
int main(void)
{
	make_answer();

	return 0;
}
void make_answer()
{
	input();
	make_tree();
	make_hld();
	make_segment_tree();
	proc();
	fin();
}
void input()
{
	int i;
	scanf("%d %d %d", &vn, &en, &qn);
	for (i = 0, num = 1; i < en; i++)
	{
		scanf("%d %d %d", &origin[i].a, &origin[i].b, &origin[i].c);
	}
}
void make_tree()
{
	int i;
	init_set();
	msort(origin, 0, en - 1);
	for (i = 0; i < en; i++)
	{
		if (find(origin[i].a) != find(origin[i].b))
		{
			uni(origin[i].a, origin[i].b);
			pb(origin[i].a, origin[i].b);
			pb(origin[i].b, origin[i].a);
			pbl(origin[i].a, origin[i].b, origin[i].c);
		}
	}
}
void init_set()
{
	int i;
	for (i = 0; i <= vn; i++)
	{
		up[i] = i;
	}
}
/*
 * Top-down merge sort of list[left..right] (inclusive) using the cmp
 * ordering; merge uses the global `sorted` array as scratch space.
 */
void msort(struct edge list[], int left, int right)
{
	int half;

	if (left >= right)
	{
		return; /* zero or one element: already ordered */
	}
	half = left + (right - left) / 2;
	msort(list, left, half);
	msort(list, half + 1, right);
	merge(list, left, half, right);
}
/*
 * Merges the adjacent sorted runs list[left..mid] and list[mid+1..right]
 * through the global scratch buffer `sorted`, then copies the result back.
 * On ties (cmp false) the element from the right run is taken first.
 */
void merge(struct edge list[], int left, int mid, int right)
{
	int a = left;    /* cursor in the left run */
	int b = mid + 1; /* cursor in the right run */
	int out = left;  /* next free slot in sorted[] */
	int t;

	while (a <= mid && b <= right)
	{
		if (!cmp(list[a], list[b]))
		{
			sorted[out++] = list[b++];
		}
		else
		{
			sorted[out++] = list[a++];
		}
	}
	/* At most one run still has elements; drain whichever it is. */
	while (a <= mid)
	{
		sorted[out++] = list[a++];
	}
	while (b <= right)
	{
		sorted[out++] = list[b++];
	}
	for (t = right; t >= left; t--)
	{
		list[t] = sorted[t];
	}
}
/* Sort predicate: non-zero when `a` must precede `b`, i.e. heavier edges come first. */
int cmp(struct edge a, struct edge b)
{
	return b.c < a.c;
}
/*
 * Returns the representative of x's set and compresses the path, so that
 * every vertex visited now points directly at the representative.
 */
int find(int x)
{
	int root = x;
	int hop;

	while (up[root] != root)
	{
		root = up[root];
	}
	while (up[x] != x)
	{
		hop = up[x];
		up[x] = root;
		x = hop;
	}
	return root;
}
/* Merges the sets containing x and y (callers pass vertices from different sets). */
void uni(int x, int y)
{
	int rx = find(x);
	int ry = find(y);

	lnk(rx, ry);
}
/*
 * Union by rank of two roots: the lower-ranked root is hung under the other.
 * On a tie, x goes under y and y's rank grows by one.
 */
void lnk(int x, int y)
{
	if (ur[x] > ur[y])
	{
		up[y] = x;
		return;
	}
	up[x] = y;
	if (ur[x] == ur[y])
	{
		ur[y]++;
	}
}
/*
 * Appends b to vertex a's adjacency list elist[a] (length eindex[a]).
 *
 * The buffer grows geometrically, and its capacity is implicit: with n
 * elements it holds the smallest power of two >= n. It is therefore full
 * exactly when n is 0 or a power of two. This replaces the old grow-by-one
 * realloc, which could copy O(deg^2) elements for a high-degree vertex.
 * If allocation fails, a message goes to stderr and the program exits.
 */
void pb(int a, int b)
{
	int n = eindex[a];
	int *grown;

	if ((n & (n - 1)) == 0)
	{
		/* Use a temporary so the old buffer is not leaked if realloc fails. */
		grown = realloc(elist[a], sizeof *grown * (size_t)(n == 0 ? 1 : 2 * n));
		if (grown == NULL)
		{
			fputs("out of memory\n", stderr);
			exit(EXIT_FAILURE);
		}
		elist[a] = grown;
	}
	elist[a][eindex[a]++] = b;
}
void pbl(int a, int b, int c)
{
	list[lidx].a = a;
	list[lidx].b = b;
	list[lidx].c = c;
	lidx++;
}
/*
 * Heavy-light decomposition of the spanning tree rooted at vertex 1, with 0
 * as its "no parent" marker. First compute subtree sizes, then number the
 * vertices chain by chain, starting on chain 1 at light depth 0.
 */
void make_hld()
{
	const int root = 1;

	go1(root, 0);
	go2(root, 0, root, 0);
}
/*
 * HLD pass 1: fills parent[] and size[] (subtree sizes) for the tree hanging
 * from `cur`, whose parent is `par` (0 for the root). Returns size[cur].
 *
 * The DFS uses an explicit stack instead of recursion. The spanning tree can
 * be a path of ~200000 vertices, and one stack frame per vertex risks
 * overflowing a default thread stack (e.g. 8 MiB on Linux).
 */
int go1(int cur, int par)
{
	/* stack: vertices on the current DFS path; next_edge[v]: next slot of elist[v] to inspect. */
	static int stack[200010], next_edge[200010];
	int top = 0;
	int u, v;

	parent[cur] = par;
	size[cur] = 1;
	next_edge[cur] = 0;
	stack[top++] = cur;
	while (top > 0)
	{
		u = stack[top - 1];
		if (next_edge[u] < eindex[u])
		{
			v = elist[u][next_edge[u]++];
			if (v != parent[u])
			{
				parent[v] = u;
				size[v] = 1;
				next_edge[v] = 0;
				stack[top++] = v;
			}
		}
		else
		{
			/* u is finished: fold its subtree size into its parent. */
			top--;
			if (top > 0)
			{
				size[stack[top - 1]] += size[u];
			}
		}
	}
	return size[cur];
}
void go2(int cur, int par, int cur_chain_num, int cur_chain_depth)
{
	int d, h, next;
	dnum[cur] = num++;
	chain_num[cur] = cur_chain_num;
	chain_depth[cur] = cur_chain_depth;
	chain_index[cur] = cindex[cur_chain_num]++;
	for (d = 0, h = -1; d < eindex[cur]; d++)
	{
		next = elist[cur][d];
		if (next != par && (h == -1 || size[next] > size[h]))
		{
			h = next;
		}
	}
	if (h != -1)
	{
		go2(h, cur, cur_chain_num, cur_chain_depth);
	}
	for (d = 0; d < eindex[cur]; d++)
	{
		next = elist[cur][d];
		if (next != par && next != h)
		{
			go2(next, cur, next, cur_chain_depth + 1);
		}
	}
}
/*
 * Builds the min segment tree over DFS positions 1..vn.
 *
 * The tree is the smallest perfect binary tree (heap layout, root `node` = 1)
 * with at least vn leaves. Leaves occupy indices [l, r], position p is stored
 * at leaf g + p, and queries cover positions [s, e]. Every slot starts at inf.
 * Each spanning-tree edge's weight is then stored at the position of its
 * child endpoint, so a vertex's slot holds the weight of the edge to its
 * parent. If allocation fails, a message goes to stderr and the program exits.
 */
void make_segment_tree(void)
{
	int i, child;

	node = s = l = r = 1;
	while (r - l + 1 < vn)
	{
		l <<= 1;
		r = (r << 1) | 1;
	}
	e = r - l + 1; /* number of leaves */
	g = l - 1;     /* leaf index = g + position */

	tree = malloc(sizeof *tree * (size_t)(r + 1));
	if (tree == NULL)
	{
		fputs("out of memory\n", stderr);
		exit(EXIT_FAILURE);
	}
	for (i = 0; i <= r; i++)
	{
		tree[i] = inf;
	}
	for (i = 0; i < lidx; i++)
	{
		child = parent[list[i].a] == list[i].b ? list[i].a : list[i].b;
		seg_update(g + dnum[child], list[i].c);
	}
}
/*
 * Writes val into tree slot idx, normally a leaf index g + position. Then
 * recomputes the min of each ancestor on the way up to the root.
 */
void seg_update(int idx, int val)
{
	int at = idx;

	tree[at] = val;
	for (at >>= 1; at != 0; at >>= 1)
	{
		tree[at] = get_min(tree[2 * at], tree[2 * at + 1]);
	}
}
/* Smaller of two ints. */
int get_min(int a, int b)
{
	return b < a ? b : a;
}
void proc()
{
	int a, b, ans;
	while (qn--)
	{
		scanf("%d %d", &a, &b);
		ans = query(a, b);
		printf("%d\n", ans);
	}
}
/*
 * Minimum edge weight on the tree path a-b. The path splits at the LCA into
 * two upward paths, and the answer is the smaller of their minima.
 */
int query(int a, int b)
{
	int meet = lca(a, b);

	return get_min(hq(meet, a), hq(meet, b));
}
/*
 * Lowest common ancestor via the heavy-light chains. While a and b are on
 * different chains, lift the one whose chain lies under more light edges (b
 * on ties) to the parent of its chain head. Once both share a chain, the
 * vertex nearer the chain head is the ancestor.
 */
int lca(int a, int b)
{
	int *lift;

	while (chain_num[a] != chain_num[b])
	{
		lift = chain_depth[a] > chain_depth[b] ? &a : &b;
		*lift = parent[chain_num[*lift]];
	}
	return chain_index[b] < chain_index[a] ? b : a;
}
/*
 * Minimum edge weight on the path from b up to its ancestor a.
 *
 * Each vertex's slot holds the weight of the edge to its parent. Whole chain
 * segments [head, b] are therefore queried while b is below a's chain. The
 * last segment starts just after a, so a's own parent edge is excluded.
 * Returns inf when a == b.
 */
int hq(int a, int b)
{
	int best = inf;
	int head;

	for (head = chain_num[b]; head != chain_num[a]; head = chain_num[b])
	{
		best = get_min(best, seg_query(node, s, e, dnum[head], dnum[b]));
		b = parent[head];
	}
	return get_min(best, seg_query(node, s, e, dnum[a] + 1, dnum[b]));
}
/*
 * Minimum over positions [left, right] within the subtree rooted at tree
 * index `node`, which covers positions [start, end]. Returns inf when the
 * ranges do not overlap, including an empty query range.
 */
int seg_query(int node, int start, int end, int left, int right)
{
	int half, from_left, from_right;

	if (end < left || start > right)
	{
		return inf;
	}
	if (start >= left && end <= right)
	{
		return tree[node];
	}
	half = (start + end) / 2;
	from_left = seg_query(2 * node, start, half, left, right);
	from_right = seg_query(2 * node + 1, half + 1, end, left, right);
	return get_min(from_left, from_right);
}
void fin()
{
	int i;
	for (i = 0, free(tree); i <= vn; i++)
	{
		if (elist[i])
		{
			free(elist[i]);
		}
	}
}