그래프가 주어졌을 때, 그 그래프의 최소 스패닝 트리를 구하는 프로그램을 작성하시오.
최소 스패닝 트리는, 주어진 그래프의 모든 정점들을 연결하는 부분 그래프 중에서 그 가중치의 합이 최소인 트리를 말한다.
첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.
그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.
첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.
입력 1
3 3
1 2 1
2 3 2
1 3 3
출력 1
3
최소 스패닝 트리를 구하는 문제 -> 크루스칼 알고리즘 선택
우선 스패닝 트리는 무방향 연결 그래프에서 사이클 없이 모든 정점을 포함하는 트리를 의미합니다.
그 중에서 최소 스패닝 트리는 스패닝 트리 중에서 가중치 합이 최소인 트리입니다.
최소 스패닝 트리를 구하는 알고리즘 중에 크루스칼 알고리즘이 있습니다.
크루스칼 알고리즘은 다음과 같이 진행됩니다.
- 모든 간선을 가중치 크기의 오름차순으로 정렬합니다.
- 이를 통해 그래프 상에서 가중치가 제일 작은 간선부터 선택할 수 있게 됩니다.
G = sorted(G, key=lambda x:x[2])
- 모든 정점의 부모 노드를 자기 자신으로 초기화합니다.
- 여기서 부모 노드는 스패닝 트리 상 부모 노드를 의미합니다.
parent = [i for i in range(V + 1)]
- 가중치 크기가 작은 순서대로 간선을 확인하면서, 해당 간선으로 이어진 두 정점의 부모 노드를 find합니다.
- if 부모 노드 다르면: 같은 스패닝 트리로 union
- else if 부모 노드 같으면: 같은 스패닝 트리에 포함되어 있다는 의미.
- 이미 이전에 두 정점이 최소의 가중치의 간선으로 이어졌기 때문에 pass한다.
for a, b, cost in G:
if not find(a) == find(b):
answer += cost
union(a, b)
def find(v):
global parent
if parent[v] == v:
return v
else:
return find(parent[v])
def union(a, b):
global parent
a = find(a)
b = find(b)
if a < b:
parent[a] = b
else:
parent[b] = a
- 3번을 진행하고 나면 모든 정점에 대해 부모 노드를 find할 경우 동일한 부모 노드를 가지게 됩니다.
크루스칼 알고리즘은 E개의 간선이 있을 때 O(ElogE) 시간복잡도를 가집니다.
E가 최대 10만이므로 통과가 가능합니다.
find 연산: O(V) (skewed tree 기준)
간선 정렬: O(ElogE) -> 시간복잡도
- 그래프 경로의 가중치 합이 최소임을 찾는 문제인 경우, 조건에 부합하는 모든 정점을 포함하는 문제인지 아닌지에 따라,
구할 것이 최단 경로인지 최소 스패닝 트리인지 확인하는 행동이 필요해 보입니다.
import sys
sys.setrecursionlimit(10**6)
input = sys.stdin.readline
def find(v):
global parent
if parent[v] == v:
return v
else:
return find(parent[v])
def is_cycle(a, b):
return True if find(a) == find(b) else False
def union(a, b):
global parent
a = find(a)
b = find(b)
if a < b:
parent[a] = b
else:
parent[b] = a
if __name__ == '__main__':
V, E = map(int, input().split())
G = [list(map(int, input().split())) for _ in range(E)]
G = sorted(G, key=lambda x:x[2])
answer = 0
parent = [i for i in range(V + 1)]
for a, b, cost in G:
if not is_cycle(a, b):
answer += cost
union(a, b)
print(answer)
최소 스패닝 트리의 튜토리얼 격 문제. 두 가지의 접근법이 있는데, 하나는 Edge-Based MST, 다른 하나는 Node-Based MST다.
Edge-Based MST는 간선을 기준으로 하여, 최단 거리의 간선부터 먼저 잇고, 그 과정에서 사이클이 생성되면 간선을 제거함으로써 최소 스패닝 트리를 만드는 방법이다.
이 알고리즘을 흔히 'Kruskal Algorithm'이라고 한다.
Kruskal Algorithm은 가장 짧은 간선부터 시작하여, 간선을 추가하기 시작하고, 이 과정에서 노드들을 그룹화 한다.
간선을 생성하면 해당 노드가 속한 그룹들을 판별하고, Ranking이 높은 그룹으로 해당 노드를 편입시킴으로써 알고리즘의 끝에 도달하면 하나의 그룹으로 노드가 자리매김하게 된다.
Node-Basded MST는 노드를 기준으로 하여, 시작 노드부터 시작해서 해당 노드의 가장 짧은 간선을 우선순위로 두어 노드들을 탐색하기 시작한다.
다익스트라와 동일한 Greedy 접근법으로, 다른 점은 다익스트라와 다르게 모든 노드가 서로 연결되었다는 것을 전제조건으로 삼아야 이를 실행할 수 있다. 그래서 실제로 다익스트라와 꽤 유사한 구조를 띄고 있다. 하나 다른 점은, 도착점이 정해져있지 않고 모든 노드가 방문한 상태가 종료 조건이 된다.
이를 Prim's Algorithm이라고 한다.
따라서 두가지 접근법을 두고 시행했다.
먼저 간선 접근법을 사용하여 해당 문제를 풀어보았다.
v,e = map(int, input().split())
edges = [[] for _ in range(e)]
parent = [i for i in range(v+1)]
rank = [0 for _ in range(v+1)]
- edges는 간선을 저장하는 배열이다.
- parent는 노드의 부모를 저장하는 배열이다. 초기에는 자신이 자신의 부모가 되도록 처리한다.
- rank는 노드의 그룹의 우선순위를 저장하는 배열이다.
def find(a):
if a == parent[a]:
return a
pa = find(parent[a])
parent[a] = pa
return parent[a]
- parameter값의 a가 부모와 동일하다면 최상위 부모를 찾았다는 의미로, 종료조건을 설정한다.
- 그것이 아니라면 재귀적으로 노드를 거슬러 올라가 부모를 찾는다.
def union(a,b):
a = find(a)
b = find(b)
if a == b:
return
if rank[a] > rank[b]:
parent[b] = a
else:
parent[a] = b
if rank[a] == rank[b]:
rank[b] += 1
- 노드 a,b를 연결하기 위한 함수다.
- a,b의 최상위 부모를 찾고, 만일 부모가 동일하다면 이 a,b는 이미 연결된 노드이므로 아무 행동도 하지 않는다.
- 아니라면 두 노드 a,b를 연결하기 위해 우선순위를 먼저 고려한 후 연결한다.
for i in range(e):
a,b,c = map(int,input().split())
edges[i] = [c,a,b]
edges.sort()
total = 0
for edge in edges:
if not edge:
continue
cost, a, b = edge
if find(a) != find(b):
union(a,b)
total += cost
- 먼저 가장 짧은 간선부터 추가할 것이므로,
edges.sort()
를 통해 정렬한다. - 그 후, 부모가 다를 떄만
union(a,b)
를 통해 둘을 묶는다. - 간선을 추가함과 동시에 해당 가중치를 추가한다.
노드 접근법은 다음과 같다. 이번에는 그래프를 선언하여 정보를 기입한다.
v,e = map(int,input().split())
graph = [[] for _ in range(v+1)]
for _ in range(e):
a,b,c = map(int,input().split())
graph[a].append([c,a,b])
graph[b].append([c,b,a])
그 후, 시작 지점을 정하고 해당 시작점부터 시작하여 가장 작은 간선을 선택하기 시작한다.
def Prim(start_node):
total_weight = 0
s_edges = graph[start_node]
visited = [0 for _ in range(v+1)]
visited[start_node] = 1
heapq.heapify(s_edges)
- s_edges는 시작점과 연결된 모든 간선의 배열을 의미한다.
- 가장 작은 간선부터 선택할 것이므로,
heapq.heapify(s_edges)
를 통해 간선의 배열을minHeap
으로 변환한다.
def Prim(start_node):
...
while s_edges:
wei, start, end = heapq.heappop(s_edges)
if visited[end] == 0:
visited[end] = 1
total_weight += wei
for n_edge in graph[end]:
if visited[n_edge[2]] == 0:
heapq.heappush(s_edges,n_edge)
return total_weight
그 후, heap
이 Empty가 될 때 까지, 계속해서 작은 간선을 먼저 택하여 새로운 노드를 방문하고, 해당 노드에 연결된 모든 간선을 heap
에 추가한 후, node의 방문상태를 true로 갱신한다.
Kruskal의 시간 복잡도는 O(ElogE), Prim의 시간복잡도는 O(ElogV)다. 따라서 만일 E > V면 Prim 알고리즘을, V > E면 Kruskal 알고리즘을 선택하는 것이 유리하다.
위 문제의 경우는 E > V이기 때문에 웬만하면 Prim 알고리즘을 선택하는 것이 훨씬 유리하다. 따라서, 이런 문제를 풀 때는 E,V의 범위를 살펴보는 것이 좋을 것 같다.
Kruskal
import sys
sys.setrecursionlimit(10**6)
v,e = map(int, input().split())
edges = [[] for _ in range(e)]
parent = [i for i in range(v+1)]
rank = [0 for _ in range(v+1)]
def find(a):
if a == parent[a]:
return a
pa = find(parent[a])
parent[a] = pa
return parent[a]
def union(a,b):
a = find(a)
b = find(b)
if a == b:
return
if rank[a] > rank[b]:
parent[b] = a
else:
parent[a] = b
if rank[a] == rank[b]:
rank[b] += 1
for i in range(e):
a,b,c = map(int,input().split())
edges[i] = [c,a,b]
edges.sort()
total = 0
for edge in edges:
if not edge:
continue
cost, a, b = edge
if find(a) != find(b):
union(a,b)
total += cost
print(total)
Prim
import heapq
import sys
INF = sys.maxsize
v,e = map(int,input().split())
graph = [[] for _ in range(v+1)]
for _ in range(e):
a,b,c = map(int,input().split())
graph[a].append([c,a,b]) #비용, 시작, 도착
graph[b].append([c,b,a])
def Prim(start_node): #시작 노드부터 차래로 탐색해 나가므로, 시작 노드를 정해주어야 한다.
total_weight = 0
s_edges = graph[start_node] #시작노드의 모든 간선들을 불러옴
visited = [0 for _ in range(v+1)]
visited[start_node] = 1
heapq.heapify(s_edges) #첫 시작 노드의 간선들을 우선순위 큐에 넣어준다.
while s_edges:
wei, start, end = heapq.heappop(s_edges)
if visited[end] == 0:
visited[end] = 1
total_weight += wei
for n_edge in graph[end]: #노드를 새로 방문했으므로, 정보가 있는 간선들을 우선순위 큐에 추가함
if visited[n_edge[2]] == 0: #완전 새로운 간선일 경우에는 새로운 간선이므로 추가
heapq.heappush(s_edges,n_edge)
return total_weight
print(Prim(1))
이전에 풀었던 문제여서 손쉽게 풀었다.
전형적인 MST를 구하는 문제다.
따라서 maps
변수를 통해 graph를 구현했다.
그 뒤엔 Prim 알고리즘을 적용했다.
시간 복잡도는 O((E+V)log(V))
다.
이 알고리즘이 Kruskal 알고리즘인줄 알고 있었지만 사실 Prim 알고리즘이였다.
import heapq
import sys
read = sys.stdin.readline
if __name__ == "__main__":
V, E = map(int, read().split())
maps = [[] for _ in range(V)]
for _ in range(E):
A, B, C = map(int, read().split())
maps[A - 1].append((B - 1, C))
maps[B - 1].append((A - 1, C))
cnt = 1
results = 0
visited = [False for _ in range(V)]
paths = [(0, 0)]
while paths and cnt <= V:
weights, from_ = heapq.heappop(paths)
if visited[from_]:
continue
visited[from_] = True
cnt += 1
results += weights
for to_, weights in maps[from_]:
if not visited[to_]:
heapq.heappush(paths, (weights, to_))
print(results)
최소 스패닝 트리를 만들면 되겠구나 바로 find union 함수 구현할 생각을 하였음 크루스칼!
- 주어진 노드 node의 루트 노드를 찾는 함수
- 재귀적으로 부모 노드를 따라가며 루트 노드를 찾음
- 이때 부모 노드를 갱신하여 경로 압축(Path Compression) 수행
- 경로 압축은 루트를 찾은 후에 해당 루트를 바로 가리키도록 부모 노드를 갱신하는 기법
- 주어진 두 노드 x와 y를 포함하는 두 트리를 합치는 함수
- 이 함수는 두 노드의 루트를 찾은 후에 하나의 루트를 다른 루트의 자식으로 설정하여 두 트리를 합침
- 메인 함수는 입력을 받고, 간선들을 가중치에 따라 정렬한 뒤 Kruskal 알고리즘을 이용하여 최소 스패닝 트리의 가중치 합을 계산하고 출력
일반적인 최소 신장 트리는 union find 함수를 잘 구현한 다음에 그래프를 잘 돌리면 되는 것 같당
import sys
def find(node, parent):
if node != parent[node]:
parent[node] = find(parent[node], parent)
return parent[node]
def union(x, y, parent):
x_root = find(x, parent)
y_root = find(y, parent)
parent[y_root] = x_root
def main():
v, e = map(int, sys.stdin.readline().split())
edges = [list(map(int, sys.stdin.readline().split())) for _ in range(e)]
edges.sort(key=lambda x: x[2])
parent = list(range(v + 1))
total_weight = 0
for start, end, weight in edges:
if find(start, parent) != find(end, parent):
union(start, end, parent)
total_weight += weight
print(total_weight)
if __name__ == "__main__":
main()