forked from TheAlgorithms/Python
- Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path self_organizing_map.py
73 lines (56 loc) · 2.03 KB
/
self_organizing_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
https://en.wikipedia.org/wiki/Self-organizing_map
"""
import math
classSelfOrganizingMap:
defget_winner(self, weights: list[list[float]], sample: list[int]) ->int:
"""
Compute the winning vector by Euclidean distance
>>> SelfOrganizingMap().get_winner([[1, 2, 3], [4, 5, 6]], [1, 2, 3])
1
"""
d0=0.0
d1=0.0
foriinrange(len(sample)):
d0+=math.pow((sample[i] -weights[0][i]), 2)
d1+=math.pow((sample[i] -weights[1][i]), 2)
return0ifd0>d1else1
return0
defupdate(
self, weights: list[list[int|float]], sample: list[int], j: int, alpha: float
) ->list[list[int|float]]:
"""
Update the winning vector.
>>> SelfOrganizingMap().update([[1, 2, 3], [4, 5, 6]], [1, 2, 3], 1, 0.1)
[[1, 2, 3], [3.7, 4.7, 6]]
"""
foriinrange(len(weights)):
weights[j][i] +=alpha* (sample[i] -weights[j][i])
returnweights
# Driver code
def main() -> None:
    """Train a two-cluster map on four toy samples and classify a test sample."""
    # Training examples: m samples of n features each -> shape (m, n)
    training_samples = [[1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 1, 1]]

    # Weight initialization: one n-feature vector per cluster -> shape (C, n)
    weights = [[0.2, 0.6, 0.5, 0.9], [0.8, 0.4, 0.7, 0.3]]

    # training
    self_organizing_map = SelfOrganizingMap()
    epochs = 3
    alpha = 0.5

    for _ in range(epochs):
        for sample in training_samples:
            # Compute the winning vector
            winner = self_organizing_map.get_winner(weights, sample)
            # Pull the winning vector towards the current sample
            weights = self_organizing_map.update(weights, sample, winner, alpha)

    # classify test sample
    sample = [0, 0, 0, 1]
    winner = self_organizing_map.get_winner(weights, sample)

    # results
    print(f"Clusters that the test sample belongs to : {winner}")
    print(f"Weights that have been trained : {weights}")


# running the main() function
if __name__ == "__main__":
    main()