- Notifications
You must be signed in to change notification settings - Fork 46.7k
/
Copy pathminimum_spanning_tree_kruskal2.py
121 lines (101 loc) · 4 KB
/
minimum_spanning_tree_kruskal2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from __future__ import annotations

from operator import itemgetter
from typing import Generic, TypeVar
# Generic type variable: the payload type of DisjointSetTreeNode/DisjointSetTree
# and the vertex-label type of GraphUndirectedWeighted (int and str in the doctests).
T = TypeVar("T")
class DisjointSetTreeNode(Generic[T]):
    """One element of a disjoint-set forest.

    A node stores its payload, a reference to its parent node and a rank used
    by union-by-rank. A freshly created node is its own parent, i.e. it is the
    root of a singleton set, and starts with rank 0.
    """

    def __init__(self, data: T) -> None:
        self.rank: int = 0
        self.parent: DisjointSetTreeNode[T] = self
        self.data = data
class DisjointSetTree(Generic[T]):
    """Disjoint-set (union-find) structure with path compression and union by rank."""

    def __init__(self) -> None:
        # map from an element to the node object that represents it
        self.map: dict[T, DisjointSetTreeNode[T]] = {}

    def make_set(self, data: T) -> None:
        """Create a new singleton set containing ``data``.

        If ``data`` is already present, its entry is replaced by a fresh
        singleton node.
        """
        self.map[data] = DisjointSetTreeNode(data)

    def find_set(self, data: T) -> DisjointSetTreeNode[T]:
        """Return the root node of the set that ``data`` belongs to.

        Every node on the path from ``data`` to the root is re-pointed directly
        at the root (path compression). The walk is iterative, so long parent
        chains cannot exhaust Python's recursion limit.

        Raises:
            KeyError: if ``data`` was never registered with ``make_set``.
        """
        node = self.map[data]

        # First pass: locate the root (the node that is its own parent).
        root = node
        while root is not root.parent:
            root = root.parent

        # Second pass: point every node on the path straight at the root.
        while node is not root:
            next_node = node.parent
            node.parent = root
            node = next_node

        return root

    def link(
        self, node1: DisjointSetTreeNode[T], node2: DisjointSetTreeNode[T]
    ) -> None:
        """Helper for ``union``: hang the lower-rank root under the higher-rank one.

        ``union`` passes the roots returned by ``find_set``. On equal ranks,
        ``node1`` goes under ``node2`` and ``node2``'s rank grows by one.
        Linking a root with itself is a no-op.
        """
        if node1 is node2:
            # Already the same set: nothing to merge. Bumping the rank here
            # would break the invariant that rank bounds the tree height.
            return
        if node1.rank > node2.rank:
            node2.parent = node1
        else:
            node1.parent = node2
            if node1.rank == node2.rank:
                node2.rank += 1

    def union(self, data1: T, data2: T) -> None:
        """Merge the sets containing ``data1`` and ``data2`` (no-op if already merged)."""
        self.link(self.find_set(data1), self.find_set(data2))
class GraphUndirectedWeighted(Generic[T]):
    """Undirected weighted graph stored as a symmetric adjacency map."""

    def __init__(self) -> None:
        # connections: map from the node to the neighbouring nodes (with weights)
        self.connections: dict[T, dict[T, int]] = {}

    def add_node(self, node: T) -> None:
        """Add ``node`` with no neighbours, ONLY if it is not already present."""
        if node not in self.connections:
            self.connections[node] = {}

    def add_edge(self, node1: T, node2: T, weight: int) -> None:
        """Add (or overwrite) the undirected edge ``node1 -- node2``.

        Missing endpoints are added first. The weight is stored in both
        directions.
        """
        self.add_node(node1)
        self.add_node(node2)
        self.connections[node1][node2] = weight
        self.connections[node2][node1] = weight

    def kruskal(self) -> GraphUndirectedWeighted[T]:
        """
        Build a Minimum Spanning Tree (MST) of the graph with Kruskal's algorithm.

        Edges are scanned in ascending weight order. An edge is kept whenever
        its endpoints are still in different components. If the graph is
        disconnected, the result is a minimum spanning forest (one tree per
        connected component). Every vertex of the graph appears in the
        returned graph, including isolated ones.

        Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
        Example:
        >>> g1 = GraphUndirectedWeighted[int]()
        >>> g1.add_edge(1, 2, 1)
        >>> g1.add_edge(2, 3, 2)
        >>> g1.add_edge(3, 4, 1)
        >>> g1.add_edge(3, 5, 100) # Removed in MST
        >>> g1.add_edge(4, 5, 5)
        >>> assert 5 in g1.connections[3]
        >>> mst = g1.kruskal()
        >>> assert 5 not in mst.connections[3]
        >>> g2 = GraphUndirectedWeighted[str]()
        >>> g2.add_edge('A', 'B', 1)
        >>> g2.add_edge('B', 'C', 2)
        >>> g2.add_edge('C', 'D', 1)
        >>> g2.add_edge('C', 'E', 100) # Removed in MST
        >>> g2.add_edge('D', 'E', 5)
        >>> assert 'E' in g2.connections["C"]
        >>> mst = g2.kruskal()
        >>> assert 'E' not in mst.connections['C']
        >>> g3 = GraphUndirectedWeighted[int]()  # two components
        >>> g3.add_edge(1, 2, 3)
        >>> g3.add_edge(3, 4, 7)
        >>> forest = g3.kruskal()
        >>> sorted(forest.connections)
        [1, 2, 3, 4]
        >>> forest.connections[1], forest.connections[3]
        ({2: 3}, {4: 7})
        """
        # Collect each undirected edge once: when (start, end) is taken, its
        # mirror (end, start) is marked as seen so it is skipped later.
        edges = []
        seen = set()
        for start, neighbours in self.connections.items():
            for end, weight in neighbours.items():
                if (start, end) not in seen:
                    seen.add((end, start))
                    edges.append((start, end, weight))
        edges.sort(key=itemgetter(2))

        # One singleton set per vertex; the result starts with every vertex.
        disjoint_set = DisjointSetTree[T]()
        graph = GraphUndirectedWeighted[T]()
        for node in self.connections:
            disjoint_set.make_set(node)
            graph.add_node(node)

        # A spanning forest over n vertices has at most n - 1 edges. Iterating
        # the edge list (instead of indexing until n - 1 edges are found) means
        # a disconnected graph no longer runs off the end of ``edges``.
        max_edges = len(self.connections) - 1
        num_edges = 0
        for u, v, w in edges:
            if num_edges >= max_edges:
                break
            if disjoint_set.find_set(u) is not disjoint_set.find_set(v):
                num_edges += 1
                graph.add_edge(u, v, w)
                disjoint_set.union(u, v)
        return graph