- Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathclassifier_free_guidance.py
853 lines (631 loc) · 28.9 KB
/
classifier_free_guidance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
importmath
importcopy
frompathlibimportPath
fromrandomimportrandom
fromfunctoolsimportpartial
fromcollectionsimportnamedtuple
frommultiprocessingimportcpu_count
importtorch
fromtorchimportnn, einsum
importtorch.nn.functionalasF
fromtorch.ampimportautocast
fromeinopsimportrearrange, reduce, repeat, pack, unpack
fromeinops.layers.torchimportRearrange
fromtqdm.autoimporttqdm
# constants
# Pair of per-step model outputs: the predicted noise and the predicted clean image x_0.
ModelPrediction = namedtuple('ModelPrediction', 'pred_noise pred_x_start')
# helper functions
def exists(x):
    """Return True for any value other than ``None`` (falsy values like 0 or '' count as existing)."""
    return not (x is None)
def default(val, d):
    """Return *val* unless it is ``None``; otherwise return the fallback *d*.

    A callable fallback is invoked (with no arguments) and its result returned,
    so expensive defaults are only built when actually needed.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
def identity(t, *args, **kwargs):
    """Return *t* unchanged; any extra positional or keyword arguments are ignored."""
    del args, kwargs  # accepted only so this can stand in for functions with richer signatures
    return t
def cycle(dl):
    """Yield the items of *dl* endlessly, starting a fresh pass each time it is exhausted."""
    while True:
        yield from dl
def has_int_squareroot(num):
    """Return True when *num* is a perfect square.

    Python ints are checked exactly with ``math.isqrt``. The float round-trip
    ``math.sqrt(num) ** 2 == num`` misclassifies large perfect squares, e.g.
    ``(2**30 + 1) ** 2``, because converting to float drops low-order bits.
    Other numeric types (e.g. floats) keep the original float check.

    Raises:
        ValueError: if *num* is negative.
    """
    if isinstance(num, int):
        root = math.isqrt(num)
        return root * root == num
    return (math.sqrt(num) ** 2) == num
def num_to_groups(num, divisor):
    """Split *num* into as many chunks of size *divisor* as fit, plus a smaller trailing chunk for any positive remainder."""
    full, rest = divmod(num, divisor)
    chunks = [divisor for _ in range(full)]
    if rest > 0:
        chunks.append(rest)
    return chunks
def convert_image_to_fn(img_type, image):
    """Return *image* in mode *img_type*, calling ``image.convert`` only when its ``mode`` differs."""
    return image if image.mode == img_type else image.convert(img_type)
def pack_one_with_inverse(x, pattern):
    """Pack the single tensor *x* with einops ``pack`` and return a matching unpacker.

    Returns:
        ``(packed, inverse)``. ``inverse(y, inverse_pattern=None)`` unpacks *y* using the
        shape recorded while packing *x*, with *pattern* unless *inverse_pattern* is given.
    """
    packed, packed_shapes = pack([x], pattern)

    def inverse(y, inverse_pattern=None):
        chosen_pattern = pattern if inverse_pattern is None else inverse_pattern
        return unpack(y, packed_shapes, chosen_pattern)[0]

    return packed, inverse
# normalization functions
def normalize_to_neg_one_to_one(img):
    """Affinely map values from [0, 1] to [-1, 1] (``2 * img - 1``)."""
    return 2 * img - 1
def unnormalize_to_zero_to_one(t):
    """Affinely map values from [-1, 1] back to [0, 1] (``(t + 1) / 2``)."""
    return 0.5 * (t + 1)
# classifier free guidance functions
def uniform(shape, device):
    """Return a float32 tensor of *shape* on *device* filled with draws from U(0, 1)."""
    buffer = torch.zeros(shape, device=device).float()
    buffer.uniform_(0, 1)
    return buffer
def prob_mask_like(shape, prob, device):
    """Return a bool mask of *shape* on *device* whose entries are True with probability *prob*.

    ``prob == 1`` and ``prob == 0`` short-circuit to all-True / all-False without
    consuming random numbers.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    draws = torch.zeros(shape, device=device).float()
    draws.uniform_(0, 1)
    return draws < prob
def project(x, y):
    """Split *x* into components parallel and orthogonal to *y*, separately per batch element.

    Each tensor is flattened to ``(b, -1)``. The projection onto the unit-normalized *y*
    is computed in float64, and both results are restored to *x*'s shape and dtype.

    Returns:
        ``(parallel, orthogonal)``, whose sum reproduces *x* up to rounding.
    """
    original_dtype = x.dtype
    flat_x, restore = pack_one_with_inverse(x, 'b *')
    flat_y, _ = pack_one_with_inverse(y, 'b *')

    flat_x = flat_x.double()
    direction = F.normalize(flat_y.double(), dim=-1)

    coeff = (flat_x * direction).sum(dim=-1, keepdim=True)
    parallel = coeff * direction
    orthogonal = flat_x - parallel

    return restore(parallel).to(original_dtype), restore(orthogonal).to(original_dtype)
# small helper modules
class Residual(nn.Module):
    """Add a skip connection around *fn*: ``forward(x, ...)`` returns ``fn(x, ...) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        branch = self.fn(x, *args, **kwargs)
        return branch + x
def Upsample(dim, dim_out=None):
    """Nearest-neighbour 2x upsample followed by a 3x3 conv from *dim* to *dim_out* channels.

    *dim_out* defaults to *dim*.
    """
    out_channels = dim if dim_out is None else dim_out
    layers = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(dim, out_channels, 3, padding=1),
    ]
    return nn.Sequential(*layers)
def Downsample(dim, dim_out=None):
    """Strided 4x4 conv (stride 2, padding 1) from *dim* to *dim_out* channels.

    It halves the spatial size, rounding down. *dim_out* defaults to *dim*.
    """
    out_channels = dim if dim_out is None else dim_out
    return nn.Conv2d(dim, out_channels, kernel_size=4, stride=2, padding=1)
class RMSNorm(nn.Module):
    """RMS normalization over dim 1 (channels) with a learned per-channel gain.

    ``F.normalize`` divides by the L2 norm over dim 1. Multiplying by ``sqrt(C)``
    turns that into division by the root-mean-square. The gain has shape
    ``(1, dim, 1, 1)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        rms_scale = x.shape[1] ** 0.5
        return F.normalize(x, dim=1) * self.g * rms_scale
class PreNorm(nn.Module):
    """Normalize the input with ``RMSNorm(dim)``, then apply *fn* to the result."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = RMSNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x))
# sinusoidal positional embeds
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-D tensor of positions / timesteps.

    For input of shape ``(b,)``, returns ``(b, 2 * (dim // 2))``: sines followed by
    cosines at ``dim // 2`` frequencies spaced geometrically from 1 down to 1/10000.
    Assumes ``dim // 2 > 1`` (otherwise the frequency step divides by zero).
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        step = math.log(10000) / (half - 1)
        freqs = torch.exp(-step * torch.arange(half, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
class RandomOrLearnedSinusoidalPosEmb(nn.Module):
    """Fourier-feature embedding of scalar timesteps with random (frozen) or learned frequencies.

    Follows @crowsonkb's lead:
    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    For a 1-D input of shape ``(b,)``, returns ``(b, dim + 1)``: the raw input, then
    ``dim // 2`` sines, then ``dim // 2`` cosines. With ``is_random=True`` the
    frequencies are not trained (``requires_grad=False``).
    """

    def __init__(self, dim, is_random=False):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2), requires_grad=not is_random)

    def forward(self, x):
        column = rearrange(x, 'b -> b 1')
        angles = column * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        return torch.cat((column, angles.sin(), angles.cos()), dim=-1)
# building block modules
class Block(nn.Module):
    """3x3 conv (padding 1) -> RMSNorm -> optional ``h * (scale + 1) + shift`` -> SiLU."""

    def __init__(self, dim, dim_out):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.norm = RMSNorm(dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        h = self.norm(self.proj(x))
        if scale_shift is not None:
            scale, shift = scale_shift
            h = h * (scale + 1) + shift
        return self.act(h)
class ResnetBlock(nn.Module):
    """Two conv ``Block``s plus a residual connection, optionally conditioned on embeddings.

    When *time_emb_dim* and/or *classes_emb_dim* is given, the available embeddings are
    concatenated and mapped by ``SiLU -> Linear`` to ``2 * dim_out`` features. These are
    split into a per-channel ``(scale, shift)`` pair that modulates the first block.
    The residual path is a 1x1 conv when ``dim != dim_out`` and identity otherwise.
    """

    def __init__(self, dim, dim_out, *, time_emb_dim=None, classes_emb_dim=None):
        super().__init__()
        has_cond = exists(time_emb_dim) or exists(classes_emb_dim)
        # Either conditioning dim may be absent: count a missing one as 0 input features.
        # (Previously int(None) raised TypeError when only one of them was provided.)
        cond_dim = int(default(time_emb_dim, 0)) + int(default(classes_emb_dim, 0))
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, dim_out * 2)
        ) if has_cond else None

        self.block1 = Block(dim, dim_out)
        self.block2 = Block(dim_out, dim_out)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None, class_emb=None):
        """Apply both blocks to *x* and add the residual.

        Any of *time_emb* / *class_emb* that are provided (and an mlp exists) are
        concatenated on the last dim in that order. The combined width must match the
        dims the block was built with.
        """
        scale_shift = None
        if exists(self.mlp) and (exists(time_emb) or exists(class_emb)):
            cond_emb = tuple(filter(exists, (time_emb, class_emb)))
            cond_emb = torch.cat(cond_emb, dim=-1)
            cond_emb = self.mlp(cond_emb)
            cond_emb = rearrange(cond_emb, 'b c -> b c 1 1')
            scale_shift = cond_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
class LinearAttention(nn.Module):
    """Multi-head linear attention over the spatial positions of a 4-D feature map.

    Queries are softmaxed over the per-head feature dim and keys over positions.
    Keys and values are first contracted into a per-head ``(dim_head x dim_head)``
    context, so no ``(n x n)`` attention matrix over the ``n = h * w`` positions is
    formed. The output passes through a 1x1 conv and ``RMSNorm``.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            RMSNorm(dim)
        )

    def forward(self, x):
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        attended = rearrange(attended, 'b h c (x y) -> b (h c) x y', h=self.heads, x=height, y=width)
        return self.to_out(attended)
class Attention(nn.Module):
    """Multi-head softmax self-attention over all spatial positions of a 4-D feature map.

    This forms a full ``(n x n)`` similarity matrix per head, with ``n = h * w``.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        _, _, height, width = x.shape
        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        sim = einsum('b h d i, b h d j -> b h i j', q * self.scale, k)
        weights = sim.softmax(dim=-1)

        out = einsum('b h i j, b h d j -> b h i d', weights, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=height, y=width)
        return self.to_out(out)
# model
class Unet(nn.Module):
    """Class-conditional U-Net with support for classifier-free guidance.

    Args:
        dim: base channel width. Stage widths are ``dim * m`` for each ``m`` in *dim_mults*.
        num_classes: size of the class-embedding table.
        cond_drop_prob: default probability, in ``forward``, of replacing a sample's class
            embedding with the learned null embedding.
        init_dim: channels after the initial 7x7 conv (defaults to *dim*).
        out_dim: output channels (defaults to *channels*, doubled if *learned_variance*).
        dim_mults: per-stage channel multipliers.
        channels: input channels.
        learned_variance: when True, the default output has ``2 * channels`` channels.
        learned_sinusoidal_cond / random_fourier_features: use
            ``RandomOrLearnedSinusoidalPosEmb`` for the time embedding instead of
            ``SinusoidalPosEmb``.
        learned_sinusoidal_dim: dimension of that learned/random embedding.
        attn_dim_head / attn_heads: head size and count of the middle full-attention layer.
    """

    def __init__(
        self,
        dim,
        num_classes,
        cond_drop_prob=0.5,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        learned_variance=False,
        learned_sinusoidal_cond=False,
        random_fourier_features=False,
        learned_sinusoidal_dim=16,
        attn_dim_head=32,
        attn_heads=4
    ):
        super().__init__()

        # classifier free guidance stuff

        self.cond_drop_prob = cond_drop_prob

        # determine dimensions

        self.channels = channels
        input_channels = channels

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        # time embeddings

        time_dim = dim * 4

        self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features

        if self.random_or_learned_sinusoidal_cond:
            sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features)
            fourier_dim = learned_sinusoidal_dim + 1  # embedding also carries the raw timestep
        else:
            sinu_pos_emb = SinusoidalPosEmb(dim)
            fourier_dim = dim

        self.time_mlp = nn.Sequential(
            sinu_pos_emb,
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # class embeddings

        self.classes_emb = nn.Embedding(num_classes, dim)
        # learned stand-in used when the class condition is dropped
        self.null_classes_emb = nn.Parameter(torch.randn(dim))

        classes_dim = dim * 4

        self.classes_mlp = nn.Sequential(
            nn.Linear(dim, classes_dim),
            nn.GELU(),
            nn.Linear(classes_dim, classes_dim)
        )

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(nn.ModuleList([
                ResnetBlock(dim_in, dim_in, time_emb_dim=time_dim, classes_emb_dim=classes_dim),
                ResnetBlock(dim_in, dim_in, time_emb_dim=time_dim, classes_emb_dim=classes_dim),
                Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=time_dim, classes_emb_dim=classes_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, dim_head=attn_dim_head, heads=attn_heads)))
        self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=time_dim, classes_emb_dim=classes_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)

            self.ups.append(nn.ModuleList([
                ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim=time_dim, classes_emb_dim=classes_dim),
                ResnetBlock(dim_out + dim_in, dim_out, time_emb_dim=time_dim, classes_emb_dim=classes_dim),
                Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1)
            ]))

        default_out_dim = channels * (1 if not learned_variance else 2)
        self.out_dim = default(out_dim, default_out_dim)

        self.final_res_block = ResnetBlock(init_dim * 2, init_dim, time_emb_dim=time_dim, classes_emb_dim=classes_dim)
        self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

    def forward_with_cond_scale(
        self,
        *args,
        cond_scale=1.,
        rescaled_phi=0.,
        remove_parallel_component=True,
        keep_parallel_frac=0.,
        **kwargs
    ):
        """Run a conditional and an unconditional pass and combine them with guidance.

        The guided output is ``logits + update * (cond_scale - 1)``, where
        ``update = logits - null_logits``. When *remove_parallel_component* is set, only
        the component of ``update`` orthogonal to ``logits`` is kept, plus
        *keep_parallel_frac* of the parallel part. A nonzero *rescaled_phi* blends the
        guided output with a copy rescaled to the per-sample std of ``logits``.

        Returns:
            ``(guided_logits, null_logits)`` for every *cond_scale*. Previously
            ``cond_scale == 1`` returned a bare tensor, which callers unpacking two
            values split along the batch dimension. The unconditional pass is therefore
            also run at ``cond_scale == 1``.
        """
        logits = self.forward(*args, cond_drop_prob=0., **kwargs)
        null_logits = self.forward(*args, cond_drop_prob=1., **kwargs)

        if cond_scale == 1:
            return logits, null_logits

        update = logits - null_logits

        if remove_parallel_component:
            parallel, orthog = project(update, logits)
            update = orthog + parallel * keep_parallel_frac

        scaled_logits = logits + update * (cond_scale - 1.)

        if rescaled_phi == 0.:
            return scaled_logits, null_logits

        std_fn = partial(torch.std, dim=tuple(range(1, scaled_logits.ndim)), keepdim=True)
        rescaled_logits = scaled_logits * (std_fn(logits) / std_fn(scaled_logits))
        interpolated_rescaled_logits = rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)

        return interpolated_rescaled_logits, null_logits

    def forward(
        self,
        x,
        time,
        classes,
        cond_drop_prob=None
    ):
        """Predict the network output for inputs *x* at timesteps *time* with labels *classes*.

        Args:
            x: input batch for the initial ``Conv2d`` (``channels`` input channels).
            time: timesteps fed to the time-embedding MLP.
            classes: integer labels indexing the class-embedding table.
            cond_drop_prob: per-sample probability of swapping the class embedding for the
                null embedding (defaults to ``self.cond_drop_prob``). 0 keeps every label;
                1 drops every label.

        Returns:
            A tensor with ``self.out_dim`` channels.
        """
        batch, device = x.shape[0], x.device

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

        # derive condition, with condition dropout for classifier free guidance

        classes_emb = self.classes_emb(classes)

        if cond_drop_prob > 0:
            keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device=device)
            null_classes_emb = repeat(self.null_classes_emb, 'd -> b d', b=batch)

            classes_emb = torch.where(
                rearrange(keep_mask, 'b -> b 1'),
                classes_emb,
                null_classes_emb
            )

        c = self.classes_mlp(classes_emb)

        # unet

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t, c)
            h.append(x)

            x = block2(x, t, c)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t, c)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t, c)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t, c)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t, c)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)

        x = self.final_res_block(x, t, c)
        return self.final_conv(x)
# gaussian diffusion trainer class
def extract(a, t, x_shape):
    """Gather ``a[t]`` for each batch element of *t*.

    The result is reshaped to ``(b, 1, ..., 1)`` so it broadcasts against a tensor of
    shape *x_shape*.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    trailing_ones = [1] * (len(x_shape) - 1)
    return picked.reshape(batch, *trailing_ones)
def linear_beta_schedule(timesteps):
    """Return *timesteps* linearly spaced float64 betas.

    The endpoints are 1e-4 and 0.02 for 1000 steps. For other step counts both
    endpoints are scaled by ``1000 / timesteps``.
    """
    factor = 1000 / timesteps
    return torch.linspace(factor * 0.0001, factor * 0.02, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.

    Builds a squared-cosine cumulative-alpha curve with offset *s* over ``timesteps + 1``
    points and normalizes it to start at 1. It returns the *timesteps* float64 betas
    ``1 - alpha_bar[i+1] / alpha_bar[i]``, clipped to [0, 0.999].
    """
    t = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    curve = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = curve / curve[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
class GaussianDiffusion(nn.Module):
    """Gaussian diffusion training loss and classifier-free-guided sampling around a class-conditional model.

    ``forward(img, classes=...)`` maps images from [0, 1] to [-1, 1] and returns the
    weighted training loss. ``sample(classes, ...)`` generates images with ancestral
    sampling, or with DDIM when ``sampling_timesteps < timesteps``.

    Args:
        model: denoising network exposing ``channels``, ``out_dim``,
            ``random_or_learned_sinusoidal_cond`` and ``forward_with_cond_scale``.
        image_size: height and width of the (square) images.
        timesteps: number of diffusion steps.
        sampling_timesteps: steps used at sampling time (defaults to *timesteps*; fewer
            switches to DDIM sampling).
        objective: 'pred_noise', 'pred_x0' or 'pred_v'.
        beta_schedule: 'linear' or 'cosine'.
        ddim_sampling_eta: DDIM noise scale (0 adds no noise between steps).
        offset_noise_strength: scale of per-(sample, channel) offset noise added in ``q_sample``.
        min_snr_loss_weight / min_snr_gamma: clamp the SNR used for loss weighting at
            *min_snr_gamma*.
        use_cfg_plus_plus: CFG++ (https://arxiv.org/pdf/2406.08070). The noise estimate used
            for the sampling update is derived from the unconditional model output.
    """

    def __init__(
        self,
        model,
        *,
        image_size,
        timesteps=1000,
        sampling_timesteps=None,
        objective='pred_noise',
        beta_schedule='cosine',
        ddim_sampling_eta=1.,
        offset_noise_strength=0.,
        min_snr_loss_weight=False,
        min_snr_gamma=5,
        use_cfg_plus_plus=False  # https://arxiv.org/pdf/2406.08070
    ):
        super().__init__()
        assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
        assert not model.random_or_learned_sinusoidal_cond

        self.model = model

        self.channels = self.model.channels

        self.image_size = image_size

        self.objective = objective

        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'

        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == 'cosine':
            betas = cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')

        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # use cfg++ when ddim sampling

        self.use_cfg_plus_plus = use_cfg_plus_plus

        # sampling related parameters

        self.sampling_timesteps = default(sampling_timesteps, timesteps)  # default num sampling timesteps to number of timesteps at training

        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32

        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others

        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)

        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)

        register_buffer('posterior_variance', posterior_variance)

        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain

        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min=1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - 0.1 was claimed ideal

        self.offset_noise_strength = offset_noise_strength

        # loss weight

        snr = alphas_cumprod / (1 - alphas_cumprod)

        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max=min_snr_gamma)

        if objective == 'pred_noise':
            loss_weight = maybe_clipped_snr / snr
        elif objective == 'pred_x0':
            loss_weight = maybe_clipped_snr
        elif objective == 'pred_v':
            loss_weight = maybe_clipped_snr / (snr + 1)

        register_buffer('loss_weight', loss_weight)

    @property
    def device(self):
        """Device holding the registered schedule buffers."""
        return self.betas.device

    def predict_start_from_noise(self, x_t, t, noise):
        """x_0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * noise."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Invert ``predict_start_from_noise``: recover the noise from x_t and x_0."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        """v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        """x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, classes, cond_scale=6., rescaled_phi=0.7, clip_x_start=False):
        """Run the guided model and convert its output into (noise, x_0) predictions.

        When *clip_x_start* is set, predicted x_0 values are clamped to [-1, 1].
        With CFG++, the returned noise estimate comes from the unconditional output.

        Returns:
            ``ModelPrediction(pred_noise, pred_x_start)``.
        """
        guided = self.model.forward_with_cond_scale(x, t, classes, cond_scale=cond_scale, rescaled_phi=rescaled_phi)

        if isinstance(guided, tuple):
            model_output, model_output_null = guided
        else:
            # forward_with_cond_scale can return a single tensor (e.g. when cond_scale == 1).
            # Unpacking that tensor as two values would split the batch dimension.
            # The unconditional output is only read under CFG++, so only run it then.
            model_output = guided
            model_output_null = self.model(x, t, classes, cond_drop_prob=1.) if self.use_cfg_plus_plus else guided

        maybe_clip = partial(torch.clamp, min=-1., max=1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output if not self.use_cfg_plus_plus else model_output_null

            x_start = self.predict_start_from_noise(x, t, model_output)
            x_start = maybe_clip(x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            x_start_for_pred_noise = x_start if not self.use_cfg_plus_plus else maybe_clip(model_output_null)

            pred_noise = self.predict_noise_from_start(x, t, x_start_for_pred_noise)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)

            x_start_for_pred_noise = x_start
            if self.use_cfg_plus_plus:
                x_start_for_pred_noise = self.predict_start_from_v(x, t, model_output_null)
                x_start_for_pred_noise = maybe_clip(x_start_for_pred_noise)

            pred_noise = self.predict_noise_from_start(x, t, x_start_for_pred_noise)

        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, classes, cond_scale, rescaled_phi, clip_denoised=True):
        """Return the posterior mean, variance and log-variance implied by the model's x_0, plus that x_0."""
        preds = self.model_predictions(x, t, classes, cond_scale, rescaled_phi)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, classes, cond_scale=6., rescaled_phi=0.7, clip_denoised=True):
        """One ancestral sampling step from integer timestep *t* (no noise added at t == 0)."""
        batched_times = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x=x, t=batched_times, classes=classes, cond_scale=cond_scale, rescaled_phi=rescaled_phi, clip_denoised=clip_denoised)
        noise = torch.randn_like(x) if t > 0 else 0.  # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, classes, shape, cond_scale=6., rescaled_phi=0.7):
        """Full ancestral sampling chain from Gaussian noise; the result is mapped with ``(x + 1) / 2``."""
        img = torch.randn(shape, device=self.betas.device)

        for t in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
            img, _ = self.p_sample(img, t, classes, cond_scale, rescaled_phi)

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def ddim_sample(self, classes, shape, cond_scale=6., rescaled_phi=0.7, clip_denoised=True):
        """DDIM sampling over ``self.sampling_timesteps`` steps; the result is mapped with ``(x + 1) / 2``."""
        batch, device, total_timesteps, sampling_timesteps, eta = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta

        times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1)  # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device=device)

        for time, time_next in tqdm(time_pairs, desc='sampling loop time step'):
            time_cond = torch.full((batch,), time, device=device, dtype=torch.long)
            pred_noise, x_start, *_ = self.model_predictions(img, time_cond, classes, cond_scale=cond_scale, rescaled_phi=rescaled_phi, clip_x_start=clip_denoised)

            if time_next < 0:
                img = x_start
                continue

            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]

            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

        img = unnormalize_to_zero_to_one(img)
        return img

    @torch.no_grad()
    def sample(self, classes, cond_scale=6., rescaled_phi=0.7):
        """Generate one square image per entry of *classes* (DDIM if configured, else ancestral)."""
        batch_size, image_size, channels = classes.shape[0], self.image_size, self.channels
        sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
        return sample_fn(classes, (batch_size, channels, image_size, image_size), cond_scale, rescaled_phi)

    @torch.no_grad()
    def interpolate(self, x1, x2, classes, t=None, lam=0.5):
        """Noise *x1* and *x2* to step *t*, mix them with weight *lam*, then denoise the mix.

        NOTE(review): *x1*/*x2* go to ``q_sample`` without ``normalize_to_neg_one_to_one``,
        and the result is not unnormalized. Denoising also starts at step ``t - 1``
        although the inputs were noised to ``t``. Confirm these are intended.
        """
        b, *_, device = *x1.shape, x1.device
        t = default(t, self.num_timesteps - 1)

        assert x1.shape == x2.shape

        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))

        img = (1 - lam) * xt1 + lam * xt2
        for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
            img, _ = self.p_sample(img, i, classes)

        return img

    @autocast('cuda', enabled=False)
    def q_sample(self, x_start, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

        With ``offset_noise_strength > 0``, per-(sample, channel) offset noise is added to
        *noise* in place. A caller-supplied tensor is therefore modified, and
        ``p_losses``' regression target includes the offset.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        if self.offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device=self.device)
            noise += self.offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, *, classes, noise=None):
        """Per-timestep-weighted MSE between the model output and the objective's target, averaged over the batch."""
        noise = default(noise, lambda: torch.randn_like(x_start))

        # noise sample

        x = self.q_sample(x_start=x_start, t=t, noise=noise)

        # predict and take gradient step

        model_out = self.model(x, t, classes)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction='none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img, *args, **kwargs):
        """Training loss for a batch of square images in [0, 1] at uniformly random timesteps."""
        b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
        assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

        img = normalize_to_neg_one_to_one(img)
        return self.p_losses(img, t, *args, **kwargs)
# example
if __name__ == '__main__':
    # Minimal end-to-end example (requires a CUDA device).
    num_classes = 10

    model = Unet(
        dim=64,
        dim_mults=(1, 2, 4, 8),
        num_classes=num_classes,
        cond_drop_prob=0.5
    )

    diffusion = GaussianDiffusion(
        model,
        image_size=128,
        timesteps=1000
    ).cuda()

    # Training images are expected in [0, 1]; GaussianDiffusion.forward maps them to [-1, 1].
    training_images = torch.rand(8, 3, 128, 128).cuda()
    image_classes = torch.randint(0, num_classes, (8,)).cuda()  # say 10 classes

    loss = diffusion(training_images, classes=image_classes)
    loss.backward()

    # do above for many steps

    sampled_images = diffusion.sample(
        classes=image_classes,
        cond_scale=6.  # condition scaling, anything greater than 1 strengthens the classifier free guidance. reportedly 3-8 is good empirically
    )

    assert sampled_images.shape == (8, 3, 128, 128)

    # interpolation between two different images

    interpolate_out = diffusion.interpolate(
        training_images[:1],
        training_images[1:2],
        image_classes[:1]
    )