BOJ 5916 농장관리

나는 야호다.

오늘은 농장관리라는 문제를 풀어볼거다.

문제: https://www.acmicpc.net/problem/5916

문제는 HLD의 기본문제다. HLD의 설명은 /hld 를 봐라.
HLD의 기본만 알고있다면, 약간 까다로운 점이 존재하는데, 바로 이 문제의 쿼리가 노드에 있는게 아니라, 간선에 있다는 점이다.

간선 쿼리의 경우에는 그 간선이 자식쪽에 해당하는 노드에 속해있다고 생각하면 된다.
따라서 경로 쿼리를 할때도 lca에 해당하는 점은 빼고 계산해야, 올바른 경로의 쿼리를 구할수 있다.

#include<iostream>
#include<vector>
#include<string>
using namespace std;
typedef unsigned long long ll;
typedef pair<ll, ll> pll;

/* 데이터 타입과 lazy의 identity 수정*/
/*--------------------------------------------*/
typedef ll TDATA;
typedef ll TLAZY;
const TLAZY ID = 0;
/* calc
	data에 lazy를 적용
	input
		lazy : 계산할 lazy값
		data : lazy를 적용할 data
		cnt : 세그먼트의 크기
*/
void calc(TLAZY lazy, TDATA& data, int cnt)
{
	data += lazy;
}
/* compose
	lazy 2개를 합성, y를 x o y로 바꿈
	input
		x, y : 위에서 주어짐
*/
void compose(TLAZY x, TLAZY& y)
{
	y += x;
}

/* merge
	세그먼트 트리의 노드 2개를 병합
	input
		l,r : 병합할 노드들
*/
TDATA merge(TDATA l, TDATA r)
{
	return l + r;
}
/*--------------------------------------------*/
class Segtree
{
private:
	vector<TDATA> v;
	vector<TLAZY> lazy;
	int sz;
	void resize(int n)
	{
		sz = n;
		v.resize(sz * 4);
		lazy.resize(sz * 4, ID);
	}
	TDATA init(vector<TDATA>& data, int node, int s, int e)
	{
		if (s + 1 == e) return v[node] = data[s];
		int mid = (s + e) / 2;
		return v[node] = merge(init(data, node * 2, s, mid), init(data, node * 2 + 1, mid, e));
	}
	TDATA query(int node, int s, int e, int l, int r)
	{
		propagate(node, s, e);
		if (r <= s || e <= l) return 0;
		if (l <= s && e <= r) return v[node];
		int mid = (s + e) / 2;
		return merge(query(node * 2, s, mid, l, r), query(node * 2 + 1, mid, e, l, r));
	}
	void propagate(int node, int s, int e)
	{
		if (lazy[node] == ID) return;
		if (s + 1 != e)
		{
			compose(lazy[node], lazy[node * 2]);
			compose(lazy[node], lazy[node * 2 + 1]);
		}
		calc(lazy[node], v[node], e - s);
		lazy[node] = ID;
	}

	void update(int node, int s, int e, int l, int r, TLAZY act)
	{
		propagate(node, s, e);
		if (r <= s || e <= l)return;
		if (l <= s && e <= r)
		{
			compose(act, lazy[node]);
			propagate(node, s, e);
			return;
		}
		int mid = (s + e) / 2;
		update(node * 2, s, mid, l, r, act);
		update(node * 2 + 1, mid, e, l, r, act);
		v[node] = merge(v[node * 2], v[node * 2 + 1]);
	}
	TDATA query(int l, int r)
	{
		return query(1, 0, sz, l, r);
	}
	void update(int l, int r, TLAZY act)
	{
		update(1, 0, sz, l, r, act);
	}

public:
	Segtree(int n = 0)
	{
		resize(n);
	}
	Segtree(vector<TDATA>& data)
	{
		resize(data.size());
		init(data, 1, 0, sz);
	}
	// 원하는 쿼리들을 작성
	TDATA range_sum(int l, int r)
	{
		return query(l, r);
	}
	void range_update(int l, int r)
	{
		update(l, r, 1LL);
	}
};

class HeavyLightDecomposition
{
private:
	vector< vector<int>> adj;
	vector<int> depth, par, weight, num, hedge, top, esc;
	Segtree* seg;
	int n;
	void dfs(int now, int prev)
	{
		par[now] = prev;
		weight[now] = 1;
		for (auto& x : adj[now])
		{
			if (x == prev)continue;
			depth[x] = depth[now] + 1;
			dfs(x, now);
			weight[now] += weight[x];
			if (hedge[now] == 0 || weight[x] > weight[hedge[now]])
				hedge[now] = x;
		}
	}
	int c = 0;
	void dfs2(int now, int prev)
	{
		num[now] = c++;
		if (hedge[now])
		{
			top[hedge[now]] = top[now];
			dfs2(hedge[now], now);
			for (auto& x : adj[now])
			{
				if (x == prev || x == hedge[now])continue;
				top[x] = x;
				dfs2(x, now);
			}
		}
		esc[now] = c;
	}
public:
	HeavyLightDecomposition(int _n) :n{ _n }
	{
		adj.resize(n + 1);
		depth.resize(n + 1);
		par.resize(n + 1);
		weight.resize(n + 1);
		num.resize(n + 1);
		hedge.resize(n + 1);
		top.resize(n + 1);
		esc.resize(n + 1);
		seg = new Segtree(n);
	}
	void add_edge(int x, int y)
	{
		adj[x].push_back(y);
		adj[y].push_back(x);
	}
	void construct(int root)
	{
		dfs(root, root);
		top[root] = root;
		dfs2(root, root);
	}
	// 원하는 쿼리들을 작성
	void qp(int x, int y)
	{
		while (top[x] != top[y])
		{
			if (depth[top[x]] < depth[top[y]])swap(x, y);
			seg->range_update(num[top[x]], num[x] + 1);
			x = par[top[x]];
		}
		if (depth[x] > depth[y])swap(x, y);
		seg->range_update(num[x] + 1, num[y] + 1);
	}
	TDATA qq(int x, int y)
	{
		TDATA ret = 0;
		while (top[x] != top[y])
		{
			if (depth[top[x]] < depth[top[y]]) swap(x, y);
			ret = merge(ret, seg->range_sum(num[top[x]], num[x] + 1));
			x = par[top[x]];
		}
		if (depth[x] > depth[y]) swap(x, y);
		ret = merge(ret, seg->range_sum(num[x] + 1, num[y] + 1));
		return ret;
	}
};
int main()
{
	cin.tie(NULL);
	cin.sync_with_stdio(false);
	int n, q;
	cin >> n >> q;
	HeavyLightDecomposition hld = HeavyLightDecomposition(n);
	for (int i = 1; i < n; i++)
	{
		int x, y;
		cin >> x >> y;
		hld.add_edge(x, y);
	}
	hld.construct(1);
	while (q--)
	{
		string s;
		int x, y;
		cin >> s >> x >> y;
		if (s == "P")
		{
			hld.qp(x, y);
		}
		else
		{
			cout << hld.qq(x, y) << "\n";
		}
	}
}