learned_gaussian_diffusion.py
import torch
from collections import namedtuple
from math import pi, sqrt, log as ln
from inspect import isfunction
from functools import partial

from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange

from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, extract, unnormalize_to_zero_to_one, identity
# constants

NAT = 1. / ln(2)

ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start', 'pred_variance'])
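# note: multiplying a quantity measured in nats by NAT (1 / ln 2) converts it to bits,
# so the variational bound terms below are measured in bits per dimension, as in the paper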
# helper functions

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d
# tensor helpers

def log(t, eps = 1e-15):
    return torch.log(t.clamp(min = eps))

def meanflat(x):
    return x.mean(dim = tuple(range(1, len(x.shape))))
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    KL divergence between normal distributions parameterized by mean and log-variance.
    """
    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
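# the expression above is the closed form
# KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)) = 0.5 * (log(sigma2^2 / sigma1^2) + (sigma1^2 + (mu1 - mu2)^2) / sigma2^2 - 1)
# evaluated elementwise on tensors of means and log-variances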
def approx_standard_normal_cdf(x):
    return 0.5 * (1.0 + torch.tanh(sqrt(2.0 / pi) * (x + 0.044715 * (x ** 3))))
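# tanh-based approximation of the standard normal CDF Phi(x), the same approximation used for GELU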
def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
    assert x.shape == means.shape == log_scales.shape

    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = approx_standard_normal_cdf(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = approx_standard_normal_cdf(min_in)
    log_cdf_plus = log(cdf_plus)
    log_one_minus_cdf_min = log(1. - cdf_min)
    cdf_delta = cdf_plus - cdf_min

    log_probs = torch.where(x < -thres,
        log_cdf_plus,
        torch.where(x > thres,
            log_one_minus_cdf_min,
            log(cdf_delta)))

    return log_probs
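# computes log p(x) for 8-bit image values rescaled to [-1, 1]: the Gaussian CDF is integrated
# over each bin of width 2 / 255 (hence the +/- 1/255 offsets), with the outermost bins
# (|x| > thres) absorbing the tails of the distribution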
# https://arxiv.org/abs/2102.09672

# i thought the results were questionable, if one were to focus only on FID
# but may as well get this in here for others to try, as GLIDE is using it (and DALL-E2 first stage of cascade)

# gaussian diffusion for learned variance + hybrid eps simple + vb loss
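# in brief, following Nichol & Dhariwal (2021): the unet predicts both the noise and the variance
# (hence the doubled output channels), and training minimizes the hybrid objective
# L_hybrid = L_simple + lambda * L_vlb, with lambda = 0.001 and the variational bound term
# only training the variance (the predicted mean is detached in that term)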
class LearnedGaussianDiffusion(GaussianDiffusion):
    def __init__(
        self,
        model,
        vb_loss_weight = 0.001,  # lambda was 0.001 in the paper
        *args,
        **kwargs
    ):
        super().__init__(model, *args, **kwargs)
        assert model.out_dim == (model.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
        assert not model.self_condition, 'not supported yet'

        self.vb_loss_weight = vb_loss_weight
    def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False, rederive_pred_noise = False):
        model_output = self.model(x, t)
        model_output, pred_variance = model_output.chunk(2, dim = 1)

        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, model_output)

        elif self.objective == 'pred_x0':
            pred_noise = self.predict_noise_from_start(x, t, model_output)
            x_start = model_output

        x_start = maybe_clip(x_start)

        return ModelPrediction(pred_noise, x_start, pred_variance)
    def p_mean_variance(self, *, x, t, clip_denoised, model_output = None, **kwargs):
        model_output = default(model_output, lambda: self.model(x, t))
        pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)

        min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
        max_log = extract(torch.log(self.betas), t, x.shape)
        var_interp_frac = unnormalize_to_zero_to_one(var_interp_frac_unnormalized)

        model_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
        model_variance = model_log_variance.exp()

        x_start = self.predict_start_from_noise(x, t, pred_noise)

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, _, _ = self.q_posterior(x_start, x, t)

        return model_mean, model_variance, model_log_variance, x_start
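    # note on the parameterization above: the network does not predict the variance directly
    # but an interpolation fraction v, mapped from [-1, 1] to [0, 1] by unnormalize_to_zero_to_one,
    # giving log sigma_t^2 = v * log(beta_t) + (1 - v) * log(beta_tilde_t),
    # where beta_tilde_t is the (clipped) posterior variance, as in Improved DDPM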
    def p_losses(self, x_start, t, noise = None, clip_denoised = False):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_t = self.q_sample(x_start = x_start, t = t, noise = noise)

        # model output

        model_output = self.model(x_t, t)

        # calculating kl loss for learned variance (interpolation)

        true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_t, t = t)
        model_mean, _, model_log_variance, _ = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)

        # kl loss with detached model predicted mean, for stability reasons as in paper

        detached_model_mean = model_mean.detach()

        kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
        kl = meanflat(kl) * NAT

        decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
        decoder_nll = meanflat(decoder_nll) * NAT

        # at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))

        vb_losses = torch.where(t == 0, decoder_nll, kl)

        # simple loss - predicting noise, x0, or x_prev

        pred_noise, _ = model_output.chunk(2, dim = 1)

        simple_losses = F.mse_loss(pred_noise, noise)

        return simple_losses + vb_losses.mean() * self.vb_loss_weight
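
# example usage - a minimal sketch, not part of the original module. it assumes the
# package-level Unet (with the `learned_variance` flag mentioned in the assert above)
# and the usual GaussianDiffusion keyword arguments (image_size, timesteps)

if __name__ == '__main__':
    from denoising_diffusion_pytorch import Unet

    unet = Unet(
        dim = 64,
        learned_variance = True        # doubles the output channels so out_dim == channels * 2
    )

    diffusion = LearnedGaussianDiffusion(
        unet,
        image_size = 128,
        timesteps = 1000,
        vb_loss_weight = 0.001         # lambda from the paper
    )

    training_images = torch.rand(4, 3, 128, 128)  # placeholder batch, values in [0, 1]

    loss = diffusion(training_images)             # hybrid loss: simple mse + weighted vb term
    loss.backward()

    # after training, sampling goes through the inherited GaussianDiffusion.sample
    # sampled_images = diffusion.sample(batch_size = 4)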