- Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathelucidated_diffusion.py
277 lines (190 loc) · 8.99 KB
/
elucidated_diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
frommathimportsqrt
fromrandomimportrandom
importtorch
fromtorchimportnn, einsum
importtorch.nn.functionalasF
fromtqdmimporttqdm
fromeinopsimportrearrange, repeat, reduce
# helpers
def exists(val):
    """Return True when *val* is anything other than None (falsy values still count)."""
    if val is None:
        return False
    return True
def default(val, d):
    """Return *val* unless it is None, otherwise the fallback *d*.

    When the fallback *d* is callable it is called with no arguments and its
    result is returned, so a default can be computed lazily.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
# tensor helpers
def log(t, eps=1e-20):
    """Natural log of tensor *t* after clamping it from below at *eps*.

    The clamp keeps zero (or negative) entries from producing -inf / NaN;
    such entries come out as log(eps).
    """
    return t.clamp(min=eps).log()
# normalization functions
def normalize_to_neg_one_to_one(img):
    """Affinely map values from [0, 1] onto [-1, 1] via x -> 2x - 1."""
    return 2 * img - 1
def unnormalize_to_zero_to_one(t):
    """Map values from [-1, 1] back onto [0, 1] via x -> (x + 1) / 2."""
    return 0.5 * (t + 1)
# main class
class ElucidatedDiffusion(nn.Module):
    """Diffusion wrapper with preconditioning, a rho-spaced noise schedule and two samplers.

    Equation/table numbers in the comments refer to "the paper". The class name
    suggests Karras et al. (2022), "Elucidating the Design Space of
    Diffusion-Based Generative Models" — confirm.

    Provides:
      * ``forward(images)``: weighted denoising training loss for a batch of images
        (mapped from [0, 1] to [-1, 1] internally).
      * ``sample(...)``: stochastic sampler that adds "churn" noise, takes an Euler
        step and applies a second-order (Heun) correction.
      * ``sample_using_dpmpp(...)``: multistep DPM-Solver++ sampler.

    ``net`` is called as ``net(c_in * x, c_noise(sigma), self_cond)``. It must expose
    a truthy ``random_or_learned_sinusoidal_cond`` attribute (asserted) and a
    ``self_condition`` attribute.
    """

    def __init__(
        self,
        net,
        *,
        image_size,
        channels=3,
        num_sample_steps=32,  # number of sampling steps
        sigma_min=0.002,      # min noise level
        sigma_max=80,         # max noise level
        sigma_data=0.5,       # standard deviation of data distribution
        rho=7,                # controls the sampling schedule
        P_mean=-1.2,          # mean of log-normal distribution from which noise is drawn for training
        P_std=1.2,            # standard deviation of log-normal distribution from which noise is drawn for training
        S_churn=80,           # parameters for stochastic sampling - depends on dataset, Table 5 in paper
        S_tmin=0.05,          # churn is only applied while S_tmin <= sigma <= S_tmax
        S_tmax=50,
        S_noise=1.003,        # scale of the fresh noise added during churn
    ):
        super().__init__()
        assert net.random_or_learned_sinusoidal_cond
        self.self_condition = net.self_condition

        self.net = net

        # image dimensions

        self.channels = channels
        self.image_size = image_size

        # parameters

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data

        self.rho = rho

        self.P_mean = P_mean
        self.P_std = P_std

        self.num_sample_steps = num_sample_steps  # otherwise known as N in the paper

        self.S_churn = S_churn
        self.S_tmin = S_tmin
        self.S_tmax = S_tmax
        self.S_noise = S_noise

    @property
    def device(self):
        """Device of the wrapped network's first parameter."""
        return next(self.net.parameters()).device

    # derived preconditioning params - Table 1

    def c_skip(self, sigma):
        """Weight on the noisy input in the denoiser: sigma_data^2 / (sigma^2 + sigma_data^2)."""
        return (self.sigma_data ** 2) / (sigma ** 2 + self.sigma_data ** 2)

    def c_out(self, sigma):
        """Weight on the network output: sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)."""
        return sigma * self.sigma_data * (self.sigma_data ** 2 + sigma ** 2) ** -0.5

    def c_in(self, sigma):
        """Input scaling for the network: 1 / sqrt(sigma^2 + sigma_data^2)."""
        return (sigma ** 2 + self.sigma_data ** 2) ** -0.5

    def c_noise(self, sigma):
        """Noise conditioning fed to the network: log(sigma) / 4 (via the module's `log` helper)."""
        return log(sigma) * 0.25

    # preconditioned network output
    # equation (7) in the paper

    def preconditioned_network_forward(self, noised_images, sigma, self_cond=None, clamp=False):
        """Denoised estimate D(x; sigma) = c_skip * x + c_out * net(c_in * x, c_noise(sigma), self_cond).

        Args:
            noised_images: batch of noisy images, assumed 4-D (b, c, h, w), since sigma
                is reshaped to (b, 1, 1, 1) to broadcast against it.
            sigma: a Python number (broadcast to every batch element) or a 1-D
                tensor of shape (b,).
            self_cond: optional self-conditioning input, passed through to ``net``.
            clamp: if True, clamp the output to [-1, 1].
        """
        batch, device = noised_images.shape[0], noised_images.device

        # Scalar noise levels (as passed by the samplers) become one value per batch element.
        # float() keeps the tensor floating-point: torch.full infers an integer dtype from an int fill value.
        if isinstance(sigma, (int, float)):
            sigma = torch.full((batch,), float(sigma), device=device)

        padded_sigma = rearrange(sigma, 'b -> b 1 1 1')

        net_out = self.net(
            self.c_in(padded_sigma) * noised_images,
            self.c_noise(sigma),
            self_cond
        )

        out = self.c_skip(padded_sigma) * noised_images + self.c_out(padded_sigma) * net_out

        if clamp:
            out = out.clamp(-1., 1.)

        return out

    # sampling

    # sample schedule
    # equation (5) in the paper

    def sample_schedule(self, num_sample_steps=None):
        """Return N noise levels interpolated from sigma_max down to sigma_min in sigma^(1/rho) space.

        A trailing 0. is appended, so the result has N + 1 entries (float32, on ``self.device``).
        """
        num_sample_steps = default(num_sample_steps, self.num_sample_steps)

        N = num_sample_steps
        inv_rho = 1 / self.rho

        steps = torch.arange(num_sample_steps, device=self.device, dtype=torch.float32)

        # max(N - 1, 1) avoids 0 / 0 (NaN sigmas) when N == 1; the schedule is then just [sigma_max]
        denom = max(N - 1, 1)
        sigmas = (self.sigma_max ** inv_rho + steps / denom * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho

        sigmas = F.pad(sigmas, (0, 1), value=0.)  # last step is sigma value of 0.
        return sigmas

    @torch.no_grad()
    def sample(self, batch_size=16, num_sample_steps=None, clamp=True):
        """Generate ``batch_size`` images with the stochastic churn + Heun sampler.

        ``clamp`` is forwarded to every network evaluation. The final images are
        clamped to [-1, 1] and returned mapped to [0, 1].
        """
        num_sample_steps = default(num_sample_steps, self.num_sample_steps)

        shape = (batch_size, self.channels, self.image_size, self.image_size)

        # Noise levels from the schedule (ending in 0.), plus a per-level churn amount gamma
        # that is non-zero only while S_tmin <= sigma <= S_tmax.
        sigmas = self.sample_schedule(num_sample_steps)

        gammas = torch.where(
            (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax),
            min(self.S_churn / num_sample_steps, sqrt(2) - 1),
            0.
        )

        # Pair each level with the next one as (sigma, sigma_next, gamma).
        sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))

        # images is noise at the beginning

        init_sigma = sigmas[0]

        images = init_sigma * torch.randn(shape, device=self.device)

        # for self conditioning

        x_start = None

        # gradually denoise

        for sigma, sigma_next, gamma in tqdm(sigmas_and_gammas, desc='sampling time step'):
            sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))

            eps = self.S_noise * torch.randn(shape, device=self.device)  # stochastic sampling

            # Raise the noise level to sigma_hat = (1 + gamma) * sigma by adding just enough fresh noise.
            sigma_hat = sigma + gamma * sigma
            images_hat = images + sqrt(sigma_hat ** 2 - sigma ** 2) * eps

            self_cond = x_start if self.self_condition else None

            model_output = self.preconditioned_network_forward(images_hat, sigma_hat, self_cond, clamp=clamp)
            # ODE slope d = (x - D(x)) / sigma at sigma_hat
            denoised_over_sigma = (images_hat - model_output) / sigma_hat

            # Euler step from sigma_hat to sigma_next
            images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma

            # second order correction, if not the last timestep

            if sigma_next != 0:
                self_cond = model_output if self.self_condition else None

                model_output_next = self.preconditioned_network_forward(images_next, sigma_next, self_cond, clamp=clamp)
                denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
                # average the slopes at both ends of the step
                images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)

            images = images_next
            x_start = model_output_next if sigma_next != 0 else model_output

        images = images.clamp(-1., 1.)
        return unnormalize_to_zero_to_one(images)

    @torch.no_grad()
    def sample_using_dpmpp(self, batch_size=16, num_sample_steps=None):
        """Generate ``batch_size`` images with a multistep DPM-Solver++ update.

        Thanks to Katherine Crowson (https://github.com/crowsonkb) for figuring it all out!
        https://arxiv.org/abs/2211.01095

        Works in t = -log(sigma). The first step and the final step (sigma_next == 0)
        use the current denoised estimate alone. Other steps extrapolate using the
        previous step's estimate. The network is called without self-conditioning
        here. Returns images clamped to [-1, 1] and mapped to [0, 1].
        """
        device, num_sample_steps = self.device, default(num_sample_steps, self.num_sample_steps)

        sigmas = self.sample_schedule(num_sample_steps)

        shape = (batch_size, self.channels, self.image_size, self.image_size)
        images = sigmas[0] * torch.randn(shape, device=device)

        sigma_fn = lambda t: t.neg().exp()
        t_fn = lambda sigma: sigma.log().neg()

        old_denoised = None
        for i in tqdm(range(len(sigmas) - 1)):
            denoised = self.preconditioned_network_forward(images, sigmas[i].item())
            # At the final step sigma_next == 0, so t_next == inf: the ratio below is 0 and
            # -expm1(-h) == 1, i.e. the update returns denoised_d.
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t

            if not exists(old_denoised) or sigmas[i + 1] == 0:
                denoised_d = denoised
            else:
                h_last = t - t_fn(sigmas[i - 1])
                r = h_last / h
                gamma = - 1 / (2 * r)
                denoised_d = (1 - gamma) * denoised + gamma * old_denoised

            images = (sigma_fn(t_next) / sigma_fn(t)) * images - (-h).expm1() * denoised_d
            old_denoised = denoised

        images = images.clamp(-1., 1.)
        return unnormalize_to_zero_to_one(images)

    # training

    def loss_weight(self, sigma):
        """Per-sample loss weight (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
        return (sigma ** 2 + self.sigma_data ** 2) * (sigma * self.sigma_data) ** -2

    def noise_distribution(self, batch_size):
        """Draw ``batch_size`` training noise levels, log-normal: exp(P_mean + P_std * N(0, 1))."""
        return (self.P_mean + self.P_std * torch.randn((batch_size,), device=self.device)).exp()

    def forward(self, images):
        """Return the scalar weighted denoising loss for a batch of images.

        Args:
            images: 4-D tensor (b, channels, image_size, image_size). It is mapped
                to [-1, 1] with normalize_to_neg_one_to_one, so [0, 1] input is assumed.
        """
        batch_size, c, h, w = images.shape
        image_size, channels = self.image_size, self.channels

        assert h == image_size and w == image_size, f'height and width of image must be {image_size}'
        assert c == channels, 'mismatch of image channels'

        images = normalize_to_neg_one_to_one(images)

        sigmas = self.noise_distribution(batch_size)
        padded_sigmas = rearrange(sigmas, 'b -> b 1 1 1')

        noise = torch.randn_like(images)

        noised_images = images + padded_sigmas * noise  # alphas are 1. in the paper

        # Self-conditioning: with probability 0.5, condition on a detached first-pass prediction.
        self_cond = None

        if self.self_condition and random() < 0.5:
            # from hinton's group's bit diffusion paper
            with torch.no_grad():
                self_cond = self.preconditioned_network_forward(noised_images, sigmas)
                self_cond.detach_()

        denoised = self.preconditioned_network_forward(noised_images, sigmas, self_cond)

        losses = F.mse_loss(denoised, images, reduction='none')
        losses = reduce(losses, 'b ... -> b', 'mean')

        losses = losses * self.loss_weight(sigmas)

        return losses.mean()