- Notifications
You must be signed in to change notification settings - Fork 458
/
Copy pathattack.py
712 lines (562 loc) · 21.2 KB
/
attack.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This program solves the NeuraCrypt challenge to 100% accuracy.
# Given a set of encoded images and original versions of those,
# it shows how to match the original to the encoded.
importcollections
importhashlib
importtime
importmultiprocessingasmp
importtorch
importnumpyasnp
importtorch.nnasnn
importscipy.stats
importmatplotlib.pyplotasplt
fromPILimportImage
importjax
importjax.numpyasjn
importobjax
importscipy.optimize
importnumpyasnp
importmultiprocessingasmp
# Objax neural network that's going to embed patches to a
# low dimensional space to guess if two patches correspond
# to the same original image.
class Model(objax.Module):
    """Two small MLPs that embed patch features onto the unit sphere.

    ``encoder`` embeds feature vectors computed from original-image patches
    and ``decoder`` embeds feature vectors computed from encoded patches.
    Both map 15 input features to an 8-dimensional vector that is rescaled
    to unit L2 norm.  ``scale`` is a single bias-free 1x1 weight; the
    training code exponentiates it and multiplies the similarity logits by it.
    """

    def __init__(self):
        n_features = 15  # size of the per-patch feature vector
        width = 64       # hidden layer width
        n_embed = 8      # size of the output embedding
        # Layers are created encoder first, then decoder, then scale.
        self.encoder = self._tower(n_features, width, n_embed)
        self.decoder = self._tower(n_features, width, n_embed)
        self.scale = objax.nn.Linear(1, 1, use_bias=False)

    @staticmethod
    def _tower(n_in, width, n_out):
        """Build a three-layer leaky-ReLU MLP mapping n_in -> n_out."""
        return objax.nn.Sequential([
            objax.nn.Linear(n_in, width),
            objax.functional.leaky_relu,
            objax.nn.Linear(width, width),
            objax.functional.leaky_relu,
            objax.nn.Linear(width, n_out)])

    @staticmethod
    def _to_unit_length(v):
        """Divide each vector along the last axis by its L2 norm."""
        return v / jn.sum(v ** 2, axis=-1, keepdims=True) ** .5

    def encode(self, x):
        # Encode turns original-image features into the shared feature space
        return self._to_unit_length(self.encoder(x))

    def decode(self, x):
        # And decode turns encoded-image features into the same feature space
        return self._to_unit_length(self.decoder(x))
# Proxy dataset for analysis
class ImageNet:
    """Encoder settings for the 3-channel ImageNet proxy dataset.

    Instances are passed as the ``args`` object of ``PrivateEncoder``.
    """

    # Number of input channels of the images.
    num_chan = 3
    # Side length (in pixels) of the square patches the encoder convolves over.
    private_kernel_size = 16
    # Width of the encoder's hidden and output representation.
    hidden_dim = 2048
    # Input image size; only the first entry is read by PrivateEncoder.
    img_size = (256, 256)
    # Number of 1x1-conv / batch-norm / ReLU stages after the patch conv.
    private_depth = 7

    def __init__(self, remove):
        # When truthy, PrivateEncoder.forward skips its random patch shuffle.
        self.remove_pixel_shuffle = remove
# Original dataset as used in the NeuraCrypt paper
class Xray:
    """Encoder settings for the 1-channel x-ray dataset.

    Instances are passed as the ``args`` object of ``PrivateEncoder``.
    """

    # Number of input channels of the images.
    num_chan = 1
    # Side length (in pixels) of the square patches the encoder convolves over.
    private_kernel_size = 16
    # Width of the encoder's hidden and output representation.
    hidden_dim = 2048
    # Input image size; only the first entry is read by PrivateEncoder.
    img_size = (256, 256)
    # Number of 1x1-conv / batch-norm / ReLU stages after the patch conv.
    private_depth = 4

    def __init__(self, remove):
        # When truthy, PrivateEncoder.forward skips its random patch shuffle.
        self.remove_pixel_shuffle = remove
## The following class is taken directly from the NeuraCrypt codebase.
## https://github.com/yala/NeuraCrypt
## which is originally licensed under the MIT License
class PrivateEncoder(nn.Module):
    """Randomly initialised patch encoder.

    The image is cut into non-overlapping ``private_kernel_size`` patches by a
    strided convolution, pushed through ``private_depth`` 1x1-conv /
    batch-norm / ReLU stages, offset by a learned per-position embedding and
    mixed by a final linear layer.  Unless ``args.remove_pixel_shuffle`` is
    set, the patch order of every image is then randomly permuted.

    Args:
        args: settings object providing ``num_chan``, ``private_kernel_size``,
            ``hidden_dim``, ``img_size``, ``private_depth`` and
            ``remove_pixel_shuffle``.  Note that ``args.input_dim`` is
            written to as a side effect.
        width_factor: multiplier applied to the hidden width.
    """

    def __init__(self, args, width_factor=1):
        super().__init__()
        self.args = args
        input_dim = args.num_chan
        patch_size = args.private_kernel_size
        output_dim = args.hidden_dim
        num_patches = (args.img_size[0] // patch_size) ** 2
        self.noise_size = 1
        args.input_dim = args.hidden_dim
        width = output_dim * width_factor

        # The first convolution has kernel == stride == patch size, so each
        # output position sees exactly one patch.
        layers = [
            nn.Conv2d(input_dim, width, kernel_size=patch_size, dilation=1, stride=patch_size),
            nn.ReLU(),
        ]
        for _ in range(self.args.private_depth):
            layers.extend([
                nn.Conv2d(width, width, kernel_size=1, dilation=1, stride=1),
                nn.BatchNorm2d(width, track_running_stats=False),
                nn.ReLU(),
            ])
        self.image_encoder = nn.Sequential(*layers)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, width))
        self.mixer = nn.Sequential(*[
            nn.ReLU(),
            nn.Linear(width, output_dim),
        ])

    def forward(self, x):
        """Encode a batch of images into ``(batch, num_patches, hidden_dim)``."""
        encoded = self.image_encoder(x)
        B, C, H, W = encoded.size()
        encoded = encoded.view([B, -1, H * W]).transpose(1, 2)
        # Out-of-place add: the tensor above is a view of the last ReLU
        # output, and mutating it in place can invalidate what autograd
        # saved for that ReLU's backward pass.
        encoded = encoded + self.pos_embedding
        encoded = self.mixer(encoded)
        ## Shuffle indices
        if not self.args.remove_pixel_shuffle:
            shuffled = torch.zeros_like(encoded)
            for i in range(B):
                # One random permutation per image, applied with a single
                # gather instead of a Python loop over every patch.
                idx = torch.randperm(H * W, device=encoded.device)
                shuffled[i] = encoded[i, idx]
            encoded = shuffled
        return encoded
## End copied code
defsetup(ds):
"""
Load the datasets to use. Nothing interesting to see.
"""
globalx_train, y_train
ifds=='imagenet':
importtorchvision
transform=torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(256),
torchvision.transforms.ToTensor()])
imagenet_data=torchvision.datasets.ImageNet('/mnt/data/datasets/unpacked_imagenet_pytorch/',
split='val',
transform=transform)
data_loader=torch.utils.data.DataLoader(imagenet_data,
batch_size=100,
shuffle=True,
num_workers=8)
r= []
forx,_indata_loader:
iflen(r) >1000: break
print(x.shape)
r.extend(x.numpy())
x_train=np.array(r)
print(x_train.shape)
elifds=='xray':
importtorchvision
transform=torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
torchvision.transforms.CenterCrop(256),
torchvision.transforms.ToTensor()])
imagenet_data=torchvision.datasets.ImageFolder('CheXpert-v1.0/train',
transform=transform)
data_loader=torch.utils.data.DataLoader(imagenet_data,
batch_size=100,
shuffle=True,
num_workers=8)
r= []
forx,_indata_loader:
iflen(r) >1000: break
print(x.shape)
r.extend(x.numpy())
x_train=np.array(r)
print(x_train.shape)
elifds=='challenge':
x_train=np.load("orig-7.npy")
print(np.min(x_train), np.max(x_train), x_train.shape)
else:
raise
def gen_train_data(num_encoders=30, images_per_encoder=100, batch_size=32):
    """
    Generate aligned training data to train a patch similarity function.

    For each of ``num_encoders`` rounds a fresh, randomly initialised
    ``PrivateEncoder`` (x-ray settings, patch shuffle disabled) encodes
    ``images_per_encoder`` images drawn with replacement from the global
    ``x_train``.  Results are stored in the globals ``encoded_train`` (one
    array of encodings per round) and ``original_train`` (the matching
    input images per round).

    Args:
        num_encoders: number of independent random encoders to sample.
        images_per_encoder: number of images encoded by each encoder.
        batch_size: images per forward pass.  The encoder's batch norm uses
            per-batch statistics, so this value influences the encodings.
    """
    global encoded_train, original_train
    encoded_train = []
    original_train = []
    args = Xray(True)
    for round_idx in range(num_encoders):
        print(round_idx)
        # int(time.time()) only changes once per second; adding the round
        # index guarantees every round gets a different seed (and therefore
        # a different encoder) even when rounds finish faster than that.
        torch.manual_seed(int(time.time()) + round_idx)
        e = PrivateEncoder(args).cuda()
        batch = np.random.randint(0, len(x_train), size=images_per_encoder)
        xin = x_train[batch]
        outputs = []
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for start in range(0, images_per_encoder, batch_size):
                chunk = torch.tensor(xin[start:start + batch_size]).cuda()
                outputs.extend(e(chunk).cpu().numpy())
        encoded_train.append(np.array(outputs))
        original_train.append(xin)
def features_(x, moments=15, encoded=False):
    """
    Compute higher-order moments for patches in an image to use as
    features for the neural network.

    ``x`` is a 3-D array whose last axis holds the values of one patch.
    The output replaces that axis with ``moments`` statistics: the mean,
    followed by ``|central moment of order k| ** (1/k)`` for
    ``k = 1 .. moments - 1``.  ``encoded`` is accepted but not used.
    """
    data = np.array(x, dtype=np.float32)
    axis = 2
    stats = [np.mean(data, axis)]
    for order in range(1, moments):
        central = scipy.stats.moment(data, moment=order, axis=axis)
        # Taking the k-th root brings every moment back to the data's scale.
        stats.append(abs(central) ** (1 / order))
    # Stack as (statistic, a, b) and move the statistic axis to the end.
    return np.array(stats).transpose((1, 2, 0))
def features(x, encoded):
    """
    Given the original images or the encoded images, generate the
    features to use for the patch similarity function.

    A 3-D input is centred by subtracting its mean over axis 0.  Any other
    input is centred over axis 1 and then has its first two axes merged.
    The rows are then split into chunks and ``features_`` is run on the
    chunks in a process pool; the per-chunk results are concatenated in
    their original order.

    Args:
        x: array of patch values, patches on the last axis.
        encoded: accepted for call-site readability; not used here.

    Raises:
        ValueError: if ``x`` has no rows (nothing to concatenate).
    """
    print('start shape', x.shape)
    if len(x.shape) == 3:
        x = x - np.mean(x, axis=0, keepdims=True)
    else:
        # presumably (num_encoders, images, patches, values) -- TODO confirm
        print(x[0].shape)
        x = x - np.mean(x, axis=1, keepdims=True)
        # remove per-neural-network dimension
        x = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
    workers = 96
    # With fewer rows than workers the integer division gives 0, which is an
    # invalid range step; fall back to one row per chunk in that case.
    chunk = max(1, len(x) // workers)
    pieces = [x[i:i + chunk] for i in range(0, len(x), chunk)]
    # The context manager tears the pool down even if a worker raises.
    with mp.Pool(workers) as pool:
        results = pool.map(features_, pieces)
    return np.concatenate(results, axis=0)
def get_train_features():
    """
    Create features for the entire training set.

    Reads the globals ``original_train`` and ``encoded_train`` and writes
    the globals ``ys_train`` (features of the encoded patches) and
    ``xs_train`` (features of the original image patches).
    """
    global xs_train, ys_train
    print(x_train.shape)
    originals = np.array(original_train)
    encodings = np.array(encoded_train)

    print("Computing features")
    ys_train = features(encodings, True)

    patch_size = 16
    n_rounds, n_images, channels = originals.shape[:3]
    ss = originals.shape[3] // patch_size
    # Cut every image into patches.  We go from
    #   [outer_batch, batch_size, channels, width, height]
    # to
    #   [outer_batch, batch_size, channels,
    #    width/patch_size, patch_size, height/patch_size, patch_size],
    # move the two patch-grid axes in front of the channel axis, and flatten to
    #   [outer_batch, batch_size, num_patches, channels * patch_size**2]
    # so that features can be computed over the last dimension.
    patches = originals.reshape(
        (n_rounds, n_images, channels, ss, patch_size, ss, patch_size))
    patches = patches.transpose((0, 1, 3, 5, 2, 4, 6))
    # The channel count belongs in the last axis; leaving it out only works
    # for single-channel images.
    patches = patches.reshape(
        (n_rounds, n_images, ss ** 2, channels * patch_size ** 2))

    xs_train = features(patches, False)
    print(xs_train.shape, ys_train.shape)
def train_model():
    """
    Train the patch similarity function.

    Builds a fresh ``Model`` (global ``model``), trains it on the flattened
    ``xs_train`` / ``ys_train`` feature rows with a contrastive loss while
    tracking an exponential moving average of the weights (global ``ema``),
    and finally saves the averaged weights to the "saved" checkpoint.
    The last 1000 rows are held out and only used for the printed scores.
    If the loss becomes NaN the function returns early, before saving.
    """
    global ema, model
    model = Model()

    def loss(x, y):
        """
        K-way contrastive loss as in SimCLR et al.

        Row i of x and row i of y should embed close to each other and far
        from every other row, so a softmax cross-entropy is applied to the
        similarity matrix with the diagonal as the correct class.
        """
        anchors = model.encode(x)
        positives = model.decode(y)
        similarity = anchors @ positives.T
        # Learnable temperature, clipped so exp() stays in a sane range.
        temperature = jn.exp(jn.clip(model.scale.w.value, -2, 4))
        return objax.functional.loss.cross_entropy_logits_sparse(
            logits=temperature * similarity,
            labels=np.arange(anchors.shape[0])).mean()

    ema = objax.optimizer.ExponentialMovingAverage(model.vars(), momentum=0.999)
    gv = objax.GradValues(loss, model.vars())
    # Versions of encode/decode that run with the averaged weights.
    encode_ema = ema.replace_vars(lambda x: model.encode(x))
    decode_ema = ema.replace_vars(lambda y: model.decode(y))
    opt = objax.optimizer.Adam(model.vars())

    def step(x, y):
        """
        One Adam update at learning rate 1e-4, followed by an EMA update.
        """
        grads, values = gv(x, y)
        opt(1e-4, grads)
        ema()
        return values

    train_op = objax.Jit(step, gv.vars() + opt.vars() + ema.vars())

    print(ys_train.shape)
    xs_ = xs_train.reshape((-1, xs_train.shape[-1]))
    ys_ = ys_train.reshape((-1, ys_train.shape[-1]))

    # The model scale trick here is taken from CLIP.
    # Let the model decide how confident to make its own predictions.
    model.scale.w.assign(jn.zeros((1, 1)))

    valid_size = 1000
    print(xs_train.shape)

    # SimCLR likes big batches
    B = 4096
    train_rows = len(xs_) - valid_size
    for it in range(80):
        print()
        ms = []
        # The first epoch uses a 64x smaller batch, to make training more stable
        bs = B if it > 0 else B // 64
        for _ in range(1000):
            batch = np.random.randint(0, train_rows, size=bs)
            r = train_op(xs_[batch], ys_[batch])
            # This shouldn't happen, but if it does, better to abort early
            if np.isnan(r):
                print("Die on nan")
                print(ms[-100:])
                return
            ms.append(r)
        print('mean', np.mean(ms), 'scale', model.scale.w.value)
        print('loss', loss(xs_[-100:], ys_[-100:]))

        # Held-out check: true pairs should score higher than random pairs.
        a = encode_ema(xs_[-valid_size:])
        b = decode_ema(ys_[-valid_size:])
        br = b[np.random.permutation(len(b))]
        matched = np.sum(a * b, axis=(1))
        mismatched = np.sum(a * br, axis=(1))
        print('score', np.mean(matched - mismatched),
              np.mean(matched > mismatched))

    ckpt = objax.io.Checkpoint("saved", keep_ckpts=0)
    # Save while the averaged weights are swapped into the model variables.
    ema.replace_vars(lambda: ckpt.save(model.vars(), 0))()
def load_challenge():
    """
    Load the challenge dataset for attacking.

    Sets the globals ``encoded`` (contents of ``challenge-7.npy``),
    ``ooriginal`` (contents of ``orig-7.npy``, untouched) and ``original``
    (the same images cut into 16x16 patches, one flattened patch per row).
    """
    global xs, ys, encoded, original, ooriginal
    print("SETUP: Loading matrixes")
    # The encoded images
    encoded = np.load("challenge-7.npy")
    # And the original images
    ooriginal = np.load("orig-7.npy")

    print("Sizes", encoded.shape, ooriginal.shape)

    # Cut each single-channel image into a grid of patches and flatten every
    # patch, so that the patch values end up on the last dimension.
    patch_size = 16
    count = ooriginal.shape[0]
    ss = ooriginal.shape[2] // patch_size
    grid = ooriginal.reshape((count, 1, ss, patch_size, ss, patch_size))
    # -> (image, patch row, patch column, channel, y in patch, x in patch)
    grid = grid.transpose((0, 2, 4, 1, 3, 5))
    original = grid.reshape((count, ss ** 2, patch_size ** 2))
def match_sub(args):
    """
    Find the best way to undo the permutation between two images.

    ``args`` is a ``(reference, candidate)`` pair of 2-D arrays with one
    vector per row.  The result assigns to every row of ``candidate`` the
    index of a distinct ``reference`` row so that the total squared
    Euclidean distance is minimal.
    """
    reference, candidate = args
    # cost[i, j] = squared distance between candidate row i and reference row j
    offsets = reference[None, :, :] - candidate[:, None, :]
    cost = np.sum(offsets ** 2, axis=2)
    _, assignment = scipy.optimize.linear_sum_assignment(cost)
    return assignment
def recover_local_permutation():
    """
    Given a set of encoded images, return a new encoding without permutations.

    Every image in the global ``encoded`` is matched, patch by patch,
    against ``encoded[0]`` and its patches are reordered accordingly; the
    reordered array replaces the global ``encoded``.
    """
    global encoded, ys
    print('recover local')
    # The context manager releases the worker processes even if a match fails.
    with mp.Pool(96) as pool:
        local_perm = np.array(
            pool.map(match_sub, [(encoded[0], e) for e in encoded]))

    # local_perm[i][k] is the reference position of patch k of image i, so
    # argsort gives, for each reference position, the patch to place there.
    encoded = np.array([image[np.argsort(perm)]
                        for image, perm in zip(encoded, local_perm)])
def recover_better_local_permutation():
    """
    Given a set of encoded images, return a new encoding, but better!

    Instead of pairing all images to image 0, the mean over all images in
    the global ``encoded`` is used as the matching target, which is slightly
    more noise resistant.  The reordered array replaces ``encoded``.
    """
    global encoded, ys
    target = encoded.mean(0)
    # The context manager releases the worker processes even if a match fails.
    with mp.Pool(96) as pool:
        local_perm = np.array(
            pool.map(match_sub, [(target, e) for e in encoded]))

    # Probably we didn't change by much, generally <0.1%
    print('improved changed by',
          np.mean(local_perm != np.arange(local_perm.shape[1])))

    encoded = np.array([image[np.argsort(perm)]
                        for image, perm in zip(encoded, local_perm)])
def compute_patch_similarity():
    """
    Compute the feature vectors for each patch using the trained neural network.

    Writes the globals ``ys`` / ``xs`` (raw features of the encoded and the
    original patches) and ``ys_image`` / ``xs_image`` (their embeddings under
    the model restored from the "saved" checkpoint).
    """
    global xs, ys, xs_image, ys_image

    print("Computing features")
    ys = features(encoded, encoded=True)
    xs = features(original, encoded=False)

    # Restore the trained weights into a freshly constructed model.
    net = Model()
    objax.io.Checkpoint("saved", keep_ckpts=0).restore(net.vars())

    xs_image = net.encode(xs)
    ys_image = net.decode(ys)
    # Sanity check: one embedding row per feature row.
    assert xs.shape[0] == xs_image.shape[0]
    print("Done")
def match(args, ret_col=False):
    """
    Compute the similarity between image features and encoded features.

    Args:
        args: ``(vec1, vec2s)`` where ``vec1`` is a 2-D array of patch
            embeddings for one image and ``vec2s`` is a sequence of such
            arrays for the candidate images.
        ret_col: accepted for compatibility; not used.

    Returns:
        A list with one score per entry of ``vec2s``: the mean dot product
        under the patch-to-patch assignment that maximises the total.

    Side effects:
        Writes a small marker file ``/tmp/start<rand>.<time>`` on every call.
    """
    vec1, vec2s = args
    marker = "/tmp/start%d.%d" % (np.random.randint(10000), time.time())
    # Close the marker file right away instead of leaking the handle.
    with open(marker, "w") as fh:
        fh.write("hi")
    scores = []
    for vec2 in vec2s:
        # value[i, j] = <vec2[i], vec1[j]>
        value = np.sum(vec1[None, :, :] * vec2[:, None, :], axis=2)
        # Negate: linear_sum_assignment minimises, we want the maximum.
        row, col = scipy.optimize.linear_sum_assignment(-value)
        scores.append(value[row, col].mean())
    return scores
def recover_global_matching_first():
    """
    Recover the global matching of original to encoded images by doing
    an all-pairs matching problem.

    Scores every (original, encoded) pair with ``match`` and solves one
    assignment over the resulting square matrix.  The result is stored in
    the global ``global_matching``: entry ``i`` is the index of the original
    image assigned to encoded image ``i``.
    """
    global global_matching, ys_image, encoded
    xs_image_ = np.array(xs_image)
    ys_image_ = np.array(ys_image)
    count = xs_image.shape[0]

    # The context manager releases the worker processes even if a match fails.
    with mp.Pool(96) as pool:
        scores = pool.map(match, [(x, ys_image_) for x in xs_image_])
    # Rows index original images, columns index encoded images.
    scores = np.array(scores).reshape((count, count))

    row, col = scipy.optimize.linear_sum_assignment(-scores)
    # col maps original -> encoded; invert it to get encoded -> original.
    global_matching = np.argsort(col)
    print('glob', list(global_matching))
def recover_global_permutation():
    """
    Find the way that the encoded images are permuted off of the original images.

    Using the current ``global_matching`` (encoded image ``i`` pairs with
    original ``global_matching[i]``), the patch-to-patch similarity matrices
    of all pairs are averaged and a single assignment is solved on the
    average.  The result is stored in the global ``global_permutation``;
    indexing an encoded image's patches with it lines them up with the
    patches of the original.
    """
    global global_permutation
    print("Glob match", global_matching)
    # Each entry [a, b] is <encoded patch a, original patch b> for one pair.
    overall = [
        np.sum(xs_image[j][None, :, :] * ys_image[i][:, None, :], axis=2)
        for i, j in enumerate(global_matching)
    ]
    overall = np.mean(overall, 0)
    row, col = scipy.optimize.linear_sum_assignment(-overall)
    new_permutation = np.argsort(col)

    # On the very first call there is no previous permutation to compare
    # with; that is the only failure worth ignoring here.
    try:
        previous = global_permutation
    except NameError:
        previous = None
    if previous is not None:
        print("Changed frac:", np.mean(previous != new_permutation))

    global_permutation = new_permutation
def recover_global_matching_second():
    """
    Match each encoded image with its original image, but better, by
    relying on the global permutation.

    The patches of every encoded image are first reordered with
    ``global_permutation``; the mean elementwise product of the embeddings
    then scores each (encoded, original) pair, and one assignment over that
    score matrix becomes the new global ``global_matching``.
    """
    global global_matching_second, global_matching

    # Undo the patch permutation on every encoded image.
    ys_fix = np.array([ys_image[i][global_permutation]
                       for i in range(ys_image.shape[0])])
    print(xs_image.shape)

    # Score ten encoded images at a time against all originals.
    sims = []
    for start in range(0, len(xs_image), 10):
        group = ys_fix[start:start + 10]
        sims.extend(np.mean(xs_image[None, :, :, :] * group[:, None, :, :],
                            axis=(2, 3)))
    sims = np.array(sims)
    print(sims.shape)

    # Rows are encoded images, columns are originals.
    _, assignment = scipy.optimize.linear_sum_assignment(-sims)

    print('arg', sims.argmax(1))
    print("Same matching frac", np.mean(assignment == global_matching))
    print(assignment)
    global_matching = assignment
def extract_by_training(resume):
    """
    Final recovery process by extracting the neural network.

    Trains the global ``inverse`` encoder for 2000 Adam steps so that its
    output on original image ``global_matching[i]`` approximates encoded
    image ``i`` with its patches reordered by ``global_permutation``.

    Args:
        resume: if falsy, ``inverse`` is replaced by a freshly initialised
            ``PrivateEncoder``; if truthy, the existing one keeps training.
            A new optimizer is created either way.
    """
    global inverse
    device = torch.device('cuda:1')

    if not resume:
        inverse = PrivateEncoder(Xray(True)).cuda(device)

    # More adam to train.
    optimizer = torch.optim.Adam(inverse.parameters(), lr=0.0001)

    this_xs = ooriginal[global_matching]
    this_ys = encoded[:, global_permutation, :]

    for _ in range(2000):
        # randint's upper bound is exclusive, so this samples the same range
        # as the deprecated (and, in NumPy 2, removed)
        # random_integers(0, len(this_xs) - 1, 32).
        idx = np.random.randint(0, len(this_xs), 32)
        xbatch = torch.tensor(this_xs[idx]).cuda(device)
        ybatch = torch.tensor(this_ys[idx]).cuda(device)

        optimizer.zero_grad()
        guess_output = inverse(xbatch)
        # L1 loss because we don't want to be sensitive to outliers
        error = torch.mean(torch.abs(guess_output - ybatch))
        error.backward()
        optimizer.step()
        print(error)
def test_extract():
    """
    Now we can recover the matching much better by computing the estimated
    encodings for each original image.

    Runs the global ``inverse`` network on every original image, compares
    each predicted encoding with every real (patch-reordered) encoding by
    squared distance, and solves one assignment on that matrix.  Updates the
    globals ``guessed_encoded``, ``smatrix`` and ``global_matching``.
    """
    global err, global_matching, guessed_encoded, smatrix

    device = torch.device('cuda:1')

    print(ooriginal.shape, encoded.shape)

    predictions = []
    for start in range(0, len(ooriginal), 32):
        print(start)
        images = torch.tensor(ooriginal[start:start + 32]).cuda(device)
        predictions.extend(inverse(images).cpu().detach().numpy())
    guessed_encoded = np.array(predictions)

    # Now we have to compare each encoded image with every other original image.
    # Do this fast with some matrix multiplies.
    count = len(encoded)
    guessed_flat = guessed_encoded.reshape((count, -1))
    real_flat = encoded[:, global_permutation, :].reshape((count, -1))

    @jax.jit
    def pairwise_sq_dist(x, y):
        # Squared Euclidean distance between every row of x and every row of y.
        return jn.square(x[:, None] - y[None, :]).sum(2)

    total = len(guessed_flat)
    smatrix = np.zeros((total, total))

    # Fill the distance matrix tile by tile.
    B = 500
    for i in range(0, total, B):
        print(i)
        for j in range(0, total, B):
            smatrix[i:i + B, j:j + B] = pairwise_sq_dist(guessed_flat[i:i + B],
                                                         real_flat[j:j + B])

    # And the final time you'll have to look at a min weight matching, I promise.
    row, col = scipy.optimize.linear_sum_assignment(np.array(smatrix))

    print(list(row)[::100])

    # col maps original -> encoded; invert it to get encoded -> original.
    refined = np.argsort(col)
    print("Differences", np.mean(refined != global_matching))
    global_matching = refined
def perf(steps=[]):
    """
    Record a timestamp and report the time since the previous and first call.

    The default ``steps`` list is shared between calls on purpose: it is
    the running log of timestamps.  The first call only records; later calls
    print the elapsed times first.  Every call then sleeps for one second.
    """
    if steps:
        print("Last Time Elapsed:", time.time() - steps[-1], ' Total Time Elapsed:', time.time() - steps[0])
    steps.append(time.time())
    time.sleep(1)
if __name__ == "__main__":
    def _timed(step, *step_args):
        """Run one pipeline step, then log timing with perf()."""
        step(*step_args)
        perf()

    # Stage 1: build training data and train the patch similarity model.
    if True:
        perf()
        _timed(setup, 'challenge')
        _timed(gen_train_data)
        _timed(get_train_features)
        _timed(train_model)

    # Stage 2: attack the challenge data.
    if True:
        load_challenge()
        perf()
        _timed(recover_local_permutation)
        _timed(recover_better_local_permutation)
        _timed(compute_patch_similarity)
        _timed(recover_global_matching_first)

        # Alternate between refining the patch permutation and the matching.
        for _ in range(3):
            _timed(recover_global_permutation)
            _timed(recover_global_matching_second)

        # Train the extracted encoder; only the first round starts from scratch.
        for i in range(3):
            _timed(recover_global_permutation)
            _timed(extract_by_training, i > 0)
            _timed(test_extract)

    print(perf())