BOJ 17274 카드 공장

문제 : https://www.acmicpc.net/problem/17274

지문이 짧고 간단하다.
앞면은 A, 뒷면은 B로 카드쌍이 주어지는데 K라는 숫자가 주어질 때마다 전체 카드 중 현재 보이고 있는 면이 K이하인 카드를 모두 뒤집어주면 된다.

가장 먼저 떠오르는 방법은 매 쿼리마다 전체 카드를 보면서 하나씩 뒤집어주는 것이지만 전체 카드의 개수와 쿼리의 개수는 각각 20만이니 O(NM)으로는 너무 당연하게 터진다.

보통 이런 문제를 만나면 아니이딴걸어케품? 이란 생각이 들지만 이 문제는 다행히도 small 버전이 있으니 small버전을 보면서 힌트를 얻어보자.

카드 공장(small) : https://www.acmicpc.net/problem/17273

small과 large의 가장 큰 차이는 small은 N이 1이라는 것이다. 단순히 카드 한 장을 보면서 현재 보는 값이 K이하면 뒤집어주면된다. 이 과정을 조금 더 자세히 살펴보자. 카드에 있는 두 숫자 중 작은 쪽을 min, 큰 쪽을 max라고 하면 K가 min보다 작으면 카드는 현재 상태를 유지한다. 반대로 K>=max라면 카드는 항상 뒤집힌다. 남은 구간은 min<=K<max인데 이 구간에서는 현재 값이 min이면 max로 뒤집히고 max라면 그대로 유지한다. 이 문제의 핵심은 이 관찰을 하는 것인데, min<=K<max 일경우 이전값에 상관없이 항상 카드가 max쪽이 앞면이 되기 때문에 이전까지 카드가 어떤면이었는지 전혀 신경쓰지 않아도 되기 때문이다.

min<=K<max 인 K값의 마지막 인덱스를 Klast라 하면 Klast 이전의 값은 신경쓸필요가 없어졌으니 Klast 이후의 K값들에 대해서 생각해보자.
이 이후의 K값들은 K>=max인 경우에만 카드가 뒤집힐텐데 우리는 현재 카드의 앞면이 max임을 알고있으니 Klast이후에 총 몇 번 카드가 뒤집히는지만 알면 마지막에 카드 앞면에 쓰여있는 숫자가 min인지 max인지 바로 알 수 있다.

또한 Klast가 존재하지 않는경우 카드는 앞면인 상태로 시작하므로 전체 범위에서 max보다 큰 K의 개수만 찾으면 이 개수가 짝수일경우 앞면, 홀수일경우 뒷면이 최종 결과가 될 것이다.

로직을 정리해보자.

  1. 각 카드에 대해 Klast를 찾는다.
  2. (Klast,M] 구간에 대해 max보다 큰 K가 총 몇 번 나오는지 센다.
  3. Klast가 존재할경우 센 결과가 짝수이면 max, 홀수이면 min, 존재하지않을경우 센 결과가 짝수이면 앞면, 홀수이면 뒷면이 최종 결과.

1번의 경우 max세그에 K별로 인덱스를 전부 때려박고 각 카드에 대해 [min,max)쿼리를 날려주면되니 K가 최대 10억임을 고려해서 좌표압축을하든 동적세그먼트트리를 짜든해서 바로 처리해줄수있다.
2번은 조금 생각을 해야하는데 [l,r]구간에서 K보다 큰 수의 개수를 찾아야하는데 잘 보면 r은 항상 M으로 고정이 되어있으므로 카드들을 Klast내림차순으로 정렬을 한 다음 지금 보고있는 카드의 Klast+1 ~ M 까지 인덱스의 K값을 전부 세그에 박고 이 중 max보다 큰 값의 개수를 세면 된다. 근데 짜고보니 그냥 머지소트트리나 pst를 짰어도 됐겠네 아..

전체 코드는 아래와 같다.

seg1과 seg2는 각각 1번2번 로직을 처리하기 위한 용도이다.

#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#include <bits/stdc++.h>
using namespace std;
using lli = long long;
using tu = tuple<int, int, int>;

struct Seg
{
	int l, r, v;
	Seg(int val) : l(0), r(0), v(val) {}
};

const int MAX = 1'000'000'000;
int N, M, arr2[200'010];
lli ans;
tu arr1[200'010];
vector<Seg> seg1(2, Seg(-1)), seg2(2, Seg(0));

void u1(int i, int v, int n = 1, int s = 1, int e = MAX);
int q1(int l, int r, int n = 1, int s = 1, int e = MAX);
void u2(int i, int n = 1, int s = 1, int e = MAX);
int q2(int l, int r, int n = 1, int s = 1, int e = MAX);

int main()
{
	ios_base::sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	cin >> N >> M;
	for (int i = 0; i < N; ++i)
	{
		int a, b; cin >> a >> b;
		arr1[i] = { a,b,-1 };
	}
	for (int i = 0; i < M; ++i)
	{
		cin >> arr2[i];
		u1(arr2[i], i);
	}
	for (int i = 0; i < N; ++i)
	{
		auto [mn, mx, idx] = arr1[i];
		if (mn > mx) swap(mn, mx);
		get<2>(arr1[i]) = q1(mn, mx - 1);
	}
	sort(arr1, arr1 + N, [](tu& t1, tu& t2) {return get<2>(t1) > get<2>(t2); });
	int cur = M - 1;
	for (int i = 0; i < N; ++i)
	{
		auto [a, b, idx] = arr1[i];
		while (cur > idx) u2(arr2[cur--]);
		int cnt = q2(max(a, b), MAX);
		if (idx == -1) ans += cnt & 1 ? b : a;
		else ans += cnt & 1 ? min(a, b) : max(a, b);
	}
	cout << ans;
}

void u1(int i, int v, int n, int s, int e)
{
	if (e<i || s>i) return;
	if (s == e) { seg1[n] = v; return; }
	int m = s + e >> 1;
	if (i <= m)
	{
		if (!seg1[n].l)
		{
			seg1[n].l = seg1.size();
			seg1.emplace_back(Seg(-1));
		}
		u1(i, v, seg1[n].l, s, m);
	}
	else
	{
		if (!seg1[n].r)
		{
			seg1[n].r = seg1.size();
			seg1.emplace_back(Seg(-1));
		}
		u1(i, v, seg1[n].r, m + 1, e);
	}
	seg1[n].v = max(seg1[seg1[n].l].v, seg1[seg1[n].r].v);
}

int q1(int l, int r, int n, int s, int e)
{
	if (e<l || s>r) return -1;
	if (l <= s && e <= r) return seg1[n].v;
	int m = s + e >> 1;
	int ret = -1;
	if (seg1[n].l) ret = max(ret, q1(l, r, seg1[n].l, s, m));
	if (seg1[n].r) ret = max(ret, q1(l, r, seg1[n].r, m + 1, e));
	return ret;
}

void u2(int i, int n, int s, int e)
{
	if (s > i || e < i) return;
	if (s == e) { seg2[n].v++; return; }
	int m = s + e >> 1;
	if (i <= m)
	{
		if (!seg2[n].l)
		{
			seg2[n].l = seg2.size();
			seg2.emplace_back(Seg(0));
		}
		u2(i, seg2[n].l, s, m);
	}
	else
	{
		if (!seg2[n].r)
		{
			seg2[n].r = seg2.size();
			seg2.emplace_back(Seg(0));
		}
		u2(i, seg2[n].r, m + 1, e);
	}
	seg2[n].v = seg2[seg2[n].l].v + seg2[seg2[n].r].v;
}

int q2(int l, int r, int n, int s, int e)
{
	if (s > r || e < l) return 0;
	if (l <= s && e <= r) return seg2[n].v;
	int m = s + e >> 1;
	int ret = 0;
	if (seg2[n].l) ret += q2(l, r, seg2[n].l, s, m);
	if (seg2[n].r) ret += q2(l, r, seg2[n].r, m + 1, e);
	return ret;
}