1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
import sys
input = sys.stdin.readline
n, m = map(int, input().split())
parent = [i for i in range(n + 1)]
def find_parent(parent, x):
if parent[x] != x:
parent[x] = find_parent(parent, parent[x])
return parent[x]
def union(parent, a, b):
a = find_parent(parent, a)
b = find_parent(parent, b)
if a < b:
parent[b] = a
elif a > b:
parent[a] = b
def solved():
for _ in range(m):
com, x, y = map(int, input().split())
if com == 0:
union(parent, x, y)
elif com == 1:
if find_parent(parent, x) == find_parent(parent, y):
print("YES")
else:
print("NO")
solved()
|
cs |
Why?
Q1. 왜 find_parent가 저렇게 작성 되어있는지? + 무엇을 위해서 저렇게 됐는지?
우선 유니온 파인드란 여러 서로소 집합의 정보를 저장하고 있는 자료구조를 의미한다.
이러한 자료구조를 만들기 위해 find는 해당 노드의 부모 노드를 찾는 함수이다.
union은 각 노드의 부모 노드를 찾은 후, 조건에 따라(ex. 둘 중 더 값이 작은 노드) 한 쪽 노드의 부모를 다른 쪽의 노드로 바꾸게 된다.
그렇다면 find_parent는
def find_parent(parent, x):
if parent[x] != x:
parent[x] = find_parent(parent, parent[x])
return parent[x]
왜 이런 식으로 구성되어 있을까?
여기에는 경로 압축 기법이라는 개념이 존재한다.
find 함수로 각 노드의 부모를 찾을 때, 각 노드에 대하여 루트 노드가 해당 노드의 부모 노드가 되게 하는 것이다.
if parent[x] != x: 부분은 해당 노드의 부모가 자기 자신이 아니라면, 그 노드의 루트 노드를 가르키게 하는 것이다.
Q2. 왜 문제 제출했을 때, 오답처리가 됐었을까?
문제를 제대로 안 읽어서이다... 문제에 잘 보면, a와 b가 같을 수도 있다는 부분을 간과했다.
문제를 잘 읽는 습관도 중요한 것 같다!