- Notifications
You must be signed in to change notification settings - Fork 46.7k
/
Copy pathmarkov_chain.py
84 lines (62 loc) · 2.04 KB
/
markov_chain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from __future__ import annotations

from collections import Counter
from random import random


class MarkovChainGraphUndirectedUnweighted:
    """
    Undirected Unweighted Graph for running Markov Chain Algorithm

    Transitions are kept in ``connections``, a mapping of
    ``node -> {destination: probability}``.  Note that a transition is only
    stored in the direction it was added (``node1 -> node2``).

    >>> graph = MarkovChainGraphUndirectedUnweighted()
    >>> graph.add_transition_probability('a', 'b', 1.0)
    >>> graph.get_nodes()
    ['a', 'b']
    >>> graph.transition('a')
    'b'
    >>> graph.transition('b')
    ''
    """

    def __init__(self) -> None:
        # node -> {destination node: probability of moving there}
        self.connections: dict[str, dict[str, float]] = {}

    def add_node(self, node: str) -> None:
        """
        Register ``node`` with no outgoing transitions.

        Adding a node that is already present is a no-op, so transitions that
        were recorded earlier are never discarded.
        """
        self.connections.setdefault(node, {})

    def add_transition_probability(
        self, node1: str, node2: str, probability: float
    ) -> None:
        """
        Record that the chain moves from ``node1`` to ``node2`` with the given
        ``probability``, creating either node if it is not known yet.

        Setting the same pair again overwrites the previous probability.
        """
        # add_node is idempotent, so existing nodes keep their transitions.
        self.add_node(node1)
        self.add_node(node2)
        self.connections[node1][node2] = probability

    def get_nodes(self) -> list[str]:
        """Return all known nodes in the order they were first added."""
        return list(self.connections)

    def transition(self, node: str) -> str:
        """
        Randomly pick the next node reachable from ``node``.

        Destinations are walked in insertion order while their probabilities
        are accumulated; the first destination whose cumulative probability
        exceeds a uniform random draw is returned.

        Returns an empty string when no destination is selected, i.e. when
        ``node`` has no outgoing transitions or the draw is not below the sum
        of its probabilities.

        Raises ``KeyError`` if ``node`` is not part of the graph.
        """
        cumulative_probability = 0.0
        random_value = random()

        for dest, probability in self.connections[node].items():
            cumulative_probability += probability
            if cumulative_probability > random_value:
                return dest
        return ""


def get_transitions(
    start: str, transitions: list[tuple[str, str, float]], steps: int
) -> dict[str, int]:
    """
    Running Markov Chain algorithm and calculating the number of times each node is
    visited

    ``transitions`` is a list of ``(source, destination, probability)`` tuples
    used to build the graph.  The walk begins at ``start`` and performs
    ``steps`` random transitions, counting the node reached after each one.

    The returned counter is seeded from the list of graph nodes, so every node
    starts with a count of 1 before the walk begins.

    Raises ``KeyError`` if a step is attempted from a node that is not in the
    graph.  That happens when ``start`` is unknown, or on the step after a
    transition selected no destination (recorded under the key ``""``).

    >>> transitions = [
    ... ('a', 'a', 0.9),
    ... ('a', 'b', 0.075),
    ... ('a', 'c', 0.025),
    ... ('b', 'a', 0.15),
    ... ('b', 'b', 0.8),
    ... ('b', 'c', 0.05),
    ... ('c', 'a', 0.25),
    ... ('c', 'b', 0.25),
    ... ('c', 'c', 0.5)
    ... ]

    >>> result = get_transitions('a', transitions, 5000)

    >>> result['a'] > result['b'] > result['c']
    True
    """
    graph = MarkovChainGraphUndirectedUnweighted()

    for node1, node2, probability in transitions:
        graph.add_transition_probability(node1, node2, probability)

    # Counter(iterable) counts occurrences, so each node is present with an
    # initial count of 1 even if the walk never reaches it.
    visited = Counter(graph.get_nodes())
    node = start

    for _ in range(steps):
        node = graph.transition(node)
        visited[node] += 1

    return visited


if __name__ == "__main__":
    # Run the examples embedded in the docstrings when executed as a script.
    import doctest

    doctest.testmod()