- Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinteractive.py
170 lines (153 loc) · 6.64 KB
/
interactive.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
fromtensorflowimportkeras
fromtensorflow.kerasimportSequential
fromtensorflow.keras.layersimportDense
frommodel_functionsimportload_data, build_model
fromtypingimportList, NoReturn, Tuple
importnumpyasnp
importrandom
importpygame
importos
# RGB colours used by the UI.
WHITE, BLACK = (255, 255, 255), (0, 0, 0)
D_BLUE = (0, 32, 96)  # dark blue, used for the panel borders

# Window size in pixels.
WIDTH = 780
HEIGHT = 380
def load_model() -> keras.Sequential:
    """Return the auto encoder model, trying progressively more expensive sources.

    1. Load a complete saved model from the relative ``model`` directory.
    2. Build a ``Sequential`` network (Dense layers of 784, 500 and 10 units,
       784 inputs) and restore weights from ``model_weights/cp.cpkt``.
    3. Train a new model with ``build_model()`` (from ``model_functions``).

    Returns:
        The loaded, rebuilt or freshly trained model.

    Raises:
        TypeError: if step 3 is needed but the ``mnist_data`` directory is
            missing or does not contain exactly two entries (training and
            test data).
    """
    # 1. Try loading model structure and weights from dir
    model = None
    try:
        model = keras.models.load_model('model')
    except (OSError, AttributeError):
        print("Encountered ERROR while loading model.\n"
              "Building model from saved weights")
    if model is not None:
        print("Loaded model from directory '/model'.")
        return model

    # 2. If above fails try building model layers then loading ONLY weights.
    # NOTE(review): the final layer has 10 units, while Window.query_ae renders
    # a 28x28 image - confirm this architecture matches the saved weights.
    model = Sequential(layers=[Dense(units=784, activation='sigmoid', input_dim=784),
                               Dense(units=500, activation='sigmoid'),
                               Dense(units=10, activation='sigmoid')])
    try:
        model.load_weights("model_weights/cp.cpkt")
    except (ImportError, ValueError):
        print("Encountered ERROR while loading weights.\n"
              "Ensure module h5py is installed and directory to weights is correct.")
    else:
        print("Created model layers and loaded weights.")
        return model

    # 3. If all above fails then train a model, store it, and return it
    print("Loading model and loading weights failed. Proceeding to\n"
          "build a model, store it and it's weights in /model and /model_weights.")
    # Relative path, like 'model' and 'model_weights' above ('/mnist_data' would
    # look in the filesystem root). Check existence first so a missing folder
    # raises the documented error instead of FileNotFoundError.
    data_dir = 'mnist_data'
    if not os.path.isdir(data_dir) or len(os.listdir(data_dir)) != 2:
        raise TypeError("Cannot execute above. mnist_data folder does not contain training and test data.")
    return build_model()
def create_text(arr: List[str], font_size: int) -> List[pygame.Surface]:
    """Render each string in *arr* as anti-aliased black text.

    Args:
        arr: Lines of text to render.
        font_size: Size passed to ``pygame.font.SysFont``.

    Returns:
        One rendered ``pygame.Surface`` per input string, in the same order.
    """
    # SysFont takes a system font *name*, not a file name: pygame strips
    # non-alphanumerics, so 'chalkduster.tff' became 'chalkdustertff' and never
    # matched. If Chalkduster is not installed pygame uses its default font.
    font = pygame.font.SysFont('chalkduster', font_size)
    return [font.render(s, True, BLACK) for s in arr]
class Window:
    """Interactive pygame window for drawing digits and viewing auto encoder output.

    The left panel is a 28x28 input canvas and the right panel shows the
    model's output.
    - Keys: Q quits; T loads a random image from ``self.x_test`` and passes
      it to the model; D passes the current drawing to the model; C clears
      both panels.
    - Holding the left mouse button darkens the input cell under the cursor.
    - Closing the window also quits.

    NOTE: the constructor starts the (blocking) render loop and only returns
    after the user quits.
    """

    GRID_SIZE = 28  # cells per side of the input/output images
    CELL_PX = 10  # on-screen width/height of one cell in pixels
    INPUT_ORIGIN = (40, 40)  # top-left corner of the input panel
    OUTPUT_ORIGIN = (450, 40)  # top-left corner of the output panel
    INK_STEP = 51  # darkening per painted frame; 255 / 51 = 5 steps to black

    def __init__(self):
        pygame.init()
        self.screen = pygame.display.set_mode((WIDTH, HEIGHT))
        self.run = True
        self.clock = pygame.time.Clock()
        # non pygame attrs: grids are lists of rows of grey levels
        # (255 = white, 0 = black); an empty list means "nothing to draw".
        self.pixels = []
        self.pixels_out = []
        self.text = create_text(["Input to Auto Encoder",
                                 "Auto Encoder Output",
                                 "[Key D: Pass Input to Auto Encoder]",
                                 "[Key T: Load Random Image]",
                                 "[Key C: Clear Input]"],
                                16)
        self.model = load_model()
        # Load the dataset once and keep the test split (third value) as the
        # pool of images for the T key.
        # NOTE(review): assumes load_data() returns
        # (x_train, y_train, x_test, y_test) - confirm against model_functions.
        _, _, self.x_test, _ = load_data()
        # continuous loop
        self.render()

    def clear_screen(self) -> None:
        """Empty both the input and the output panels."""
        self.pixels = []
        self.pixels_out = []

    def random_image(self) -> int:
        """Put a random image from ``self.x_test`` on the input canvas.

        Returns:
            The index of the chosen image, so ``query_ae`` can be fed the same
            sample.
        """
        self.clear_screen()
        i = random.randint(0, len(self.x_test) - 1)
        # Scale stored values by 255 and invert so they match the canvas, where
        # 255 is white. NOTE(review): assumes images are 2-D with values in [0, 1].
        for vec in self.x_test[i]:
            self.pixels.append(list(np.array(255 - (vec * 255), dtype="int16")))
        return i

    def query_ae(self, i=0, mnist=True):
        """Run the model and store its output as grey levels in ``self.pixels_out``.

        Args:
            i: Index into ``self.x_test`` used when *mnist* is true.
            mnist: If true, feed ``self.x_test[i]``; otherwise feed the user's
                drawing from ``self.pixels`` (inverted and scaled to [0, 1]).

        Does nothing while the input canvas is empty.
        """
        if not self.pixels:
            return
        if mnist:
            model_input = [self.x_test[i].reshape(-1, 28, 28, 1)]
        else:
            # preprocess user drawing: 255 (white) -> 0.0, 0 (black) -> 1.0
            model_input = ((255.0 - np.array(self.pixels, dtype="float32")) / 255.0).reshape(-1, 28, 28, 1)
        ae_out = self.model.predict(model_input)[0]
        # Clip values above 1, invert back to grey levels (truncating like the
        # int16 cast does), and keep channel 0 of each pixel.
        grey = (255 - np.minimum(ae_out, 1) * 255).astype("int16")
        self.pixels_out = grey[:, :, 0].tolist()

    def draw_pixels(self, pixels: List[List[int]], x_offset: int, y_offset: int) -> None:
        """Draw a grid of grey levels as CELL_PX squares starting at (x_offset, y_offset).

        Values outside [0, 255] are clamped *in place* in *pixels*. The
        painting code relies on this to pull over-darkened cells back to 0.
        """
        size = self.CELL_PX
        for y, row in enumerate(pixels):
            for x, p in enumerate(row):
                if p > 255:
                    row[x] = 255
                elif p < 0:
                    row[x] = 0
                p = row[x]
                pygame.draw.rect(self.screen, (p, p, p),
                                 [x_offset + x * size, y_offset + y * size, size, size])

    def draw_text(self, coords: List[Tuple[int, int]]) -> None:
        """Blit each pre-rendered label in ``self.text`` at the matching coordinate."""
        for text_obj, xy_pair in zip(self.text, coords):
            self.screen.blit(text_obj, xy_pair)

    @classmethod
    def _canvas_cell(cls, mouse_x, mouse_y):
        """Return the (row, col) of the input cell under a window point, or None if outside.

        Uses the same geometry as ``draw_pixels``, so the painted cell is the
        one drawn under the cursor.
        """
        rel_x = mouse_x - cls.INPUT_ORIGIN[0]
        rel_y = mouse_y - cls.INPUT_ORIGIN[1]
        extent = cls.GRID_SIZE * cls.CELL_PX
        if not (0 <= rel_x < extent and 0 <= rel_y < extent):
            return None
        return rel_y // cls.CELL_PX, rel_x // cls.CELL_PX

    def _paint_cell(self, row: int, col: int) -> None:
        """Darken one input cell by INK_STEP, creating a blank canvas on first use."""
        if not self.pixels:
            self.pixels = [[255] * self.GRID_SIZE for _ in range(self.GRID_SIZE)]
        if self.pixels[row][col] != 0:
            self.pixels[row][col] -= self.INK_STEP

    def render(self) -> None:
        """Run the event/draw loop at up to 144 FPS until quit, then shut pygame down."""
        while self.run:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    # Window close button: previously ignored.
                    self.run = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_q:
                        self.run = False
                    elif event.key == pygame.K_t:
                        self.query_ae(self.random_image())
                    elif event.key == pygame.K_d:
                        self.query_ae(mnist=False)
                    elif event.key == pygame.K_c:
                        self.clear_screen()
            self.screen.fill(WHITE)
            # --[render start]--
            self.draw_text([(110, 20), (530, 20), (50, 330), (50, 345), (50, 360)])
            self.draw_pixels(self.pixels, *self.INPUT_ORIGIN)  # input section
            self.draw_pixels(self.pixels_out, *self.OUTPUT_ORIGIN)  # output section
            pygame.draw.rect(self.screen, D_BLUE, [40, 37, 283, 285], 5)  # input border
            pygame.draw.rect(self.screen, D_BLUE, [450, 37, 283, 285], 5)  # output border
            # Handle mouse input (for drawing)
            if pygame.mouse.get_pressed(3)[0]:
                cell = self._canvas_cell(*pygame.mouse.get_pos())
                if cell is not None:
                    self._paint_cell(*cell)
            # --[render end]--
            pygame.display.flip()
            self.clock.tick(144)
        pygame.quit()
if __name__ == "__main__":
    # Window() runs the pygame event loop inside its constructor and only
    # returns once the user quits.
    Window()