- Notifications
You must be signed in to change notification settings - Fork 228
/
Copy pathplot_sandwich.py
105 lines (83 loc) · 3.02 KB
/
plot_sandwich.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# -*- coding: utf-8 -*-
"""
Sandwich demo
=============
Sandwich demo based on code from http://nbviewer.ipython.org/6576096
"""
######################################################################
# .. note::
#
# In order to show the charts of the examples you need a graphical
# ``matplotlib`` backend installed. For instance, use ``pip install pyqt5``
# to get a Qt graphical interface, or use your favorite one.
importnumpyasnp
frommatplotlibimportpyplotasplt
fromsklearn.metricsimportpairwise_distances
fromsklearn.neighborsimportNearestNeighbors
frommetric_learnimport (LMNN, ITML_Supervised, LSML_Supervised,
SDML_Supervised)
def sandwich_demo():
    """Plot the sandwich data before and after several metric learners.

    The figure has a wide top panel with the raw data and its neighbour
    graph, followed by one panel per learner (LMNN, ITML_Supervised,
    SDML_Supervised, LSML_Supervised) showing the transformed data and the
    neighbour graph recomputed in the transformed space. The figure is
    displayed with ``plt.show()``.
    """
    points, classes = sandwich_data()
    input_neighbors = nearest_neighbors(points, k=2)

    # Cell 1 of a 3x1 grid: spans the whole top row of the figure.
    top_panel = plt.subplot(3, 1, 1)
    plot_sandwich_data(points, classes, top_panel)
    plot_neighborhood_graph(points, input_neighbors, classes, top_panel)
    top_panel.set_title('input space')
    top_panel.set_aspect('equal')
    top_panel.set_xticks([])
    top_panel.set_yticks([])

    learners = (
        LMNN(),
        ITML_Supervised(n_constraints=200),
        SDML_Supervised(n_constraints=200, balance_param=0.001),
        LSML_Supervised(n_constraints=200),
    )

    # Cells 3..6 of a 3x2 grid are the two rows below the top panel.
    cell = 3
    for learner in learners:
        learner.fit(points, classes)
        embedded = learner.transform(points)
        embedded_neighbors = nearest_neighbors(embedded, k=2)

        panel = plt.subplot(3, 2, cell)
        plot_sandwich_data(embedded, classes, axis=panel)
        plot_neighborhood_graph(embedded, embedded_neighbors, classes,
                                axis=panel)
        panel.set_title(learner.__class__.__name__)
        panel.set_xticks([])
        panel.set_yticks([])
        cell += 1

    plt.show()
# TODO: use this somewhere
def visualize_class_separation(X, labels):
    """Show pairwise-distance images for ``X`` and ``labels`` side by side.

    Rows are reordered so that samples are sorted by label. The left image
    is the pairwise-distance matrix of the reordered ``X``; the right image
    is the pairwise-distance matrix of the reordered labels, treated as a
    single-column feature matrix. Nothing is returned and the figure is not
    shown by this function.
    """
    _, axes = plt.subplots(ncols=2)
    data_ax, label_ax = axes

    order = np.argsort(labels)
    sorted_points = X[order]
    # Labels become an (n, 1) column so they can be fed to pairwise_distances.
    label_column = labels[order][:, None]

    data_ax.imshow(pairwise_distances(sorted_points),
                   interpolation='nearest')
    label_ax.imshow(pairwise_distances(label_column),
                    interpolation='nearest')
def nearest_neighbors(X, k=5):
    """Return the indices of the ``k`` nearest rows of ``X`` for each row.

    A ``NearestNeighbors`` index is fitted on ``X`` and then queried with
    ``X`` itself; only the neighbour indices are returned, not the
    distances.
    """
    index = NearestNeighbors(n_neighbors=k).fit(X)
    neighbor_ids = index.kneighbors(X, return_distance=False)
    return neighbor_ids
def sandwich_data(num_classes=6, num_points=9, dist=0.7, noise=0.1):
    """Generate the layered 2-D "sandwich" toy dataset.

    Each class forms one horizontal layer: its points sit at unit spacing
    along the x axis, and consecutive layers are ``dist`` apart along the
    y axis. Every coordinate is drawn from a normal distribution centred on
    its grid position, using the global ``np.random`` state.

    Parameters
    ----------
    num_classes : int, default 6
        Number of distinct classes (layers).
    num_points : int, default 9
        Number of points per class.
    dist : float, default 0.7
        Distance between the centres of consecutive layers.
    noise : float, default 0.1
        Standard deviation of the normal noise applied to each coordinate.

    Returns
    -------
    data : ndarray of shape (num_classes * num_points, 2)
        Point coordinates, grouped by class.
    labels : ndarray of shape (num_classes * num_points,)
        Integer class label (``0 .. num_classes - 1``) of each point.
    """
    x_centers = np.arange(num_points, dtype=float) - num_points / 2
    y_centers = dist * (np.arange(num_classes, dtype=float) - num_classes / 2)

    # Grid of noise-free positions: centers[i, k] = (x_centers[k], y_centers[i]).
    centers = np.empty((num_classes, num_points, 2), dtype=float)
    centers[..., 0] = x_centers[np.newaxis, :]
    centers[..., 1] = y_centers[:, np.newaxis]

    # One broadcast draw replaces the per-coordinate scalar draws.
    data = np.random.normal(centers, noise)
    labels = np.repeat(np.arange(num_classes, dtype=int), num_points)
    return data.reshape((num_classes * num_points, 2)), labels
def plot_sandwich_data(x, y, axis=plt, colors='rbgmky'):
    """Scatter the rows of ``x`` as hollow markers, one colour per class.

    The distinct values of ``y`` are taken in ``np.unique`` order; the
    class at position ``n`` in that order is drawn with ``colors[n]`` as
    its edge colour. ``axis`` is any object with a ``scatter`` method
    (``pyplot`` itself by default).
    """
    for rank, label in enumerate(np.unique(y)):
        members = x[y == label]
        # Transposing yields one coordinate array per column of ``x``.
        coords = members.T
        axis.scatter(*coords, s=50, facecolors='none',
                     edgecolors=colors[rank])
def plot_neighborhood_graph(x, nn, y, axis=plt, colors='rbgmky'):
    """Draw a line from every point to its neighbour in column 1 of ``nn``.

    Parameters
    ----------
    x : array of 2-D points, one per row.
    nn : integer array of neighbour indices; ``nn[i, 1]`` is the row of
        ``x`` that point ``i`` is connected to. Column 0 is skipped
        (presumably because it is the query point itself -- confirm against
        how ``nn`` was built).
    y : 1-D array of class labels, one per row of ``x``.
    axis : object with a ``plot`` method (``pyplot`` itself by default).
    colors : sequence of matplotlib colour codes, indexed by class rank.

    The line colour of point ``i`` is chosen by the position of ``y[i]``
    among the sorted distinct labels rather than by the raw label value,
    so labels that are not exactly ``0 .. n-1`` still receive a valid
    colour, using the same label-to-colour rule as a scatter coloured by
    ``np.unique`` order.
    """
    # color_rank[i] is the index of y[i] within the sorted unique labels.
    _, color_rank = np.unique(y, return_inverse=True)
    for i, start in enumerate(x):
        end = x[nn[i, 1]]
        axis.plot((start[0], end[0]), (start[1], end[1]),
                  colors[color_rank[i]])
# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    sandwich_demo()