kmeans.py
# coding:utf-8
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from mla.base import BaseEstimator
from mla.metrics.distance import euclidean_distance

random.seed(1111)


class KMeans(BaseEstimator):
    """Partition a dataset into K clusters.

    Finds clusters by repeatedly assigning each data point to the cluster with
    the nearest centroid and iterating until the assignments converge (meaning
    they don't change during an iteration) or the maximum number of iterations
    is reached.

    Parameters
    ----------
    K : int
        The number of clusters into which the dataset is partitioned.
    max_iters : int
        The maximum iterations of assigning points to the nearest cluster.
        Short-circuited by the assignments converging on their own.
    init : str, default 'random'
        The name of the method used to initialize the first clustering.

        'random' - Randomly select values from the dataset as the K centroids.
        '++' - Select a random first centroid from the dataset, then select
               K - 1 more centroids by choosing values from the dataset with a
               probability distribution proportional to the squared distance
               from each point's closest existing cluster. Attempts to create
               larger distances between initial clusters to improve convergence
               rates and avoid degenerate cases.
    """

    y_required = False

    def __init__(self, K=5, max_iters=100, init="random"):
        self.K = K
        self.max_iters = max_iters
        self.clusters = [[] for _ in range(self.K)]
        self.centroids = []
        self.init = init

    def _initialize_centroids(self, init):
        """Set the initial centroids."""
        if init == "random":
            self.centroids = [self.X[x] for x in random.sample(range(self.n_samples), self.K)]
        elif init == "++":
            self.centroids = [random.choice(self.X)]
            while len(self.centroids) < self.K:
                self.centroids.append(self._choose_next_center())
        else:
            raise ValueError("Unknown type of init parameter")

    def _predict(self, X=None):
        """Perform clustering on the dataset."""
        self._initialize_centroids(self.init)
        centroids = self.centroids

        # Optimize clusters
        for _ in range(self.max_iters):
            self._assign(centroids)
            centroids_old = centroids
            centroids = [self._get_centroid(cluster) for cluster in self.clusters]

            if self._is_converged(centroids_old, centroids):
                break

        self.centroids = centroids

        return self._get_predictions()

    def _get_predictions(self):
        """Map each sample index to the index of the cluster it belongs to."""
        predictions = np.empty(self.n_samples)

        for i, cluster in enumerate(self.clusters):
            for index in cluster:
                predictions[index] = i
        return predictions

    def _assign(self, centroids):
        """Reassign every sample to the cluster with the nearest centroid."""
        for row in range(self.n_samples):
            # Remove the sample from its current cluster, if it has one.
            for i, cluster in enumerate(self.clusters):
                if row in cluster:
                    self.clusters[i].remove(row)
                    break

            closest = self._closest(row, centroids)
            self.clusters[closest].append(row)

    def _closest(self, fpoint, centroids):
        """Find the closest centroid for a point."""
        closest_index = None
        closest_distance = None
        for i, point in enumerate(centroids):
            dist = euclidean_distance(self.X[fpoint], point)
            if closest_index is None or dist < closest_distance:
                closest_index = i
                closest_distance = dist
        return closest_index

    def _get_centroid(self, cluster):
        """Get values by indices and take the mean."""
        return [np.mean(np.take(self.X[:, i], cluster)) for i in range(self.n_features)]

    def _dist_from_centers(self):
        """Calculate each point's distance to its nearest existing centroid."""
        return np.array([min([euclidean_distance(x, c) for c in self.centroids]) for x in self.X])

    def _choose_next_center(self):
        """Sample the next '++' centroid with probability proportional to squared distance."""
        distances = self._dist_from_centers()
        squared_distances = distances ** 2
        probs = squared_distances / squared_distances.sum()
        ind = np.random.choice(self.X.shape[0], 1, p=probs)[0]
        return self.X[ind]

    def _is_converged(self, centroids_old, centroids):
        """Check if the distance between old and new centroids is zero."""
        distance = 0
        for i in range(self.K):
            distance += euclidean_distance(centroids_old[i], centroids[i])
        return distance == 0

    def plot(self, ax=None, holdon=False):
        """Scatter-plot the clusters and mark each centroid with a cross."""
        sns.set(style="white")
        palette = sns.color_palette("hls", self.K + 1)
        data = self.X

        if ax is None:
            _, ax = plt.subplots()

        for i, index in enumerate(self.clusters):
            point = np.array(data[index]).T
            ax.scatter(*point, c=[palette[i], ])

        for point in self.centroids:
            ax.scatter(*point, marker="x", linewidths=10)

        if not holdon:
            plt.show()