Home 프림(Prim) 알고리즘(기본 예제)
Post
Cancel

프림(Prim) 알고리즘(기본 예제)

기초 강의 바로가기

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from heapq import *
from collections import defaultdict
# Sample undirected weighted graph: one (weight, node, node) tuple per edge.
graph = [
    (7, 'A', 'B'),
    (5, 'A', 'D'),
    (8, 'B', 'C'),
    (8, 'B', 'D'),
    (7, 'B', 'E'),
    (5, 'C', 'E'),
    (7, 'D', 'E'),
    (6, 'D', 'F'),
    (8, 'E', 'F'),
    (9, 'E', 'G'),
    (11, 'F', 'G'),
]

def prim(start_node, edges):
    """Build a minimum spanning tree with Prim's algorithm.

    Args:
        start_node: Node the tree is grown from. Must be hashable; node
            labels must also be mutually orderable, because edges with equal
            weight are ordered by comparing the remaining tuple fields.
        edges: Iterable of ``(weight, node_a, node_b)`` tuples describing an
            undirected weighted graph.

    Returns:
        A list of ``(weight, from_node, to_node)`` tuples in the order the
        edges were added to the tree. Only nodes reachable from
        ``start_node`` are covered; a ``start_node`` with no edges yields
        an empty list.
    """
    # Store every edge under both endpoints, oriented away from the key node,
    # so that edge[2] is always the node on the far side.
    adj_edges = defaultdict(list)
    for weight, n1, n2 in edges:
        adj_edges[n1].append((weight, n1, n2))
        adj_edges[n2].append((weight, n2, n1))

    # A set literal keeps the label intact; set(start_node) would iterate it,
    # splitting 'AB' into {'A', 'B'} and raising TypeError for an int label.
    connected_nodes = {start_node}

    # Copy before heapifying so the heap operations never reorder or grow the
    # adjacency list, and use .get so an unknown start node adds no key.
    candidate_edge_list = list(adj_edges.get(start_node, ()))
    heapify(candidate_edge_list)

    mst = []
    while candidate_edge_list:
        weight, n1, n2 = heappop(candidate_edge_list)
        # Stale entry: the far node was connected after this edge was pushed.
        if n2 in connected_nodes:
            continue

        connected_nodes.add(n2)
        mst.append((weight, n1, n2))

        # Only edges leading outside the tree can still be useful.
        for edge in adj_edges[n2]:
            if edge[2] not in connected_nodes:
                heappush(candidate_edge_list, edge)
    return mst


# Run Prim's algorithm on the sample graph starting from node 'A' and print
# the list of (weight, from_node, to_node) edges it returns.
print(prim('A', graph))
This post is licensed under CC BY 4.0 by the author.