- Notifications
You must be signed in to change notification settings - Fork 5.9k
/
Copy pathpipeline_flax_stable_diffusion.py
477 lines (407 loc) · 20.3 KB
/
pipeline_flax_stable_diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
importwarnings
fromfunctoolsimportpartial
fromtypingimportDict, List, Optional, Union
importjax
importjax.numpyasjnp
importnumpyasnp
fromflax.core.frozen_dictimportFrozenDict
fromflax.jax_utilsimportunreplicate
fromflax.training.common_utilsimportshard
frompackagingimportversion
fromPILimportImage
fromtransformersimportCLIPImageProcessor, CLIPTokenizer, FlaxCLIPTextModel
from ...modelsimportFlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...schedulersimport (
FlaxDDIMScheduler,
FlaxDPMSolverMultistepScheduler,
FlaxLMSDiscreteScheduler,
FlaxPNDMScheduler,
)
from ...utilsimportdeprecate, logging, replace_example_docstring
from ..pipeline_flax_utilsimportFlaxDiffusionPipeline
from .pipeline_outputimportFlaxStableDiffusionPipelineOutput
from .safety_checker_flaximportFlaxStableDiffusionSafetyChecker
# Module-level logger obtained from the project's logging utility.
logger=logging.get_logger(__name__) # pylint: disable=invalid-name
# Set to True to run the denoising loop as a plain Python `for` loop instead of
# `jax.lax.fori_loop`, which makes stepping through the loop body easier when debugging.
DEBUG=False
# Usage example that `replace_example_docstring` attaches to
# `FlaxStableDiffusionPipeline.__call__`. The text is runtime data and is kept verbatim.
EXAMPLE_DOC_STRING="""
Examples:
```py
>>> import jax
>>> import numpy as np
>>> from flax.jax_utils import replicate
>>> from flax.training.common_utils import shard
>>> from diffusers import FlaxStableDiffusionPipeline
>>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
... "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="bf16", dtype=jax.numpy.bfloat16
... )
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> prng_seed = jax.random.PRNGKey(0)
>>> num_inference_steps = 50
>>> num_samples = jax.device_count()
>>> prompt = num_samples * [prompt]
>>> prompt_ids = pipeline.prepare_inputs(prompt)
# shard inputs and rng
>>> params = replicate(params)
>>> prng_seed = jax.random.split(prng_seed, jax.device_count())
>>> prompt_ids = shard(prompt_ids)
>>> images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
>>> images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
"""
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
    r"""
    Flax-based pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`FlaxAutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.FlaxCLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`FlaxUNet2DConditionModel`]):
            A `FlaxUNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or
            [`FlaxDPMSolverMultistepScheduler`].
        safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
            more details about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    def __init__(
        self,
        vae: FlaxAutoencoderKL,
        text_encoder: FlaxCLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        safety_checker: FlaxStableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        dtype: jnp.dtype = jnp.float32,
    ):
        """
        Register the pipeline components and patch legacy UNet configurations.

        A warning is logged when `safety_checker` is `None`. When the UNet config reports a diffusers version older
        than 0.9.0 together with a `sample_size` below 64, a deprecation notice is emitted and the in-memory config
        is overridden with `sample_size=64`.
        """
        super().__init__()
        self.dtype = dtype

        if safety_checker is None:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        # Only the release part of the version is compared, so dev/rc suffixes are ignored.
        is_unet_version_less_0_9_0 = (
            unet is not None
            and hasattr(unet.config, "_diffusers_version")
            and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0")
        )
        is_unet_sample_size_less_64 = (
            unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        )
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
                " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            # Replace the frozen config with a copy that carries the corrected sample size.
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        # Derived from the number of VAE blocks; falls back to 8 when no VAE is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8

    def prepare_inputs(self, prompt: Union[str, List[str]]):
        """
        Tokenize `prompt` into NumPy input ids padded/truncated to the tokenizer's maximum length.

        Raises:
            ValueError: If `prompt` is neither a `str` nor a `list`.
        """
        if not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
        return text_input.input_ids

    def _get_has_nsfw_concepts(self, features, params):
        """Run the safety checker on `features` with `params` and return its result."""
        has_nsfw_concepts = self.safety_checker(features, params)
        return has_nsfw_concepts

    def _run_safety_checker(self, images, safety_model_params, jit=False):
        """
        Check `images` with the safety checker and black out the flagged ones.

        Args:
            images: Indexable batch of `uint8` images (each item must be accepted by `PIL.Image.fromarray`).
            safety_model_params: Safety checker parameters; should already be replicated when `jit` is `True`.
            jit: Whether to shard the features and run the `pmap`-ed safety checker.

        Returns:
            A tuple `(images, has_nsfw_concepts)`. `images` is the input itself when nothing was flagged, otherwise
            a copy in which every flagged image is replaced by zeros.
        """
        pil_images = [Image.fromarray(image) for image in images]
        features = self.feature_extractor(pil_images, return_tensors="np").pixel_values

        if jit:
            features = shard(features)
            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
            has_nsfw_concepts = unshard(has_nsfw_concepts)
        else:
            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                # Copy lazily so the caller's array is never modified in place.
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image

        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts

    def _generate(
        self,
        prompt_ids: jnp.ndarray,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.Array,
        num_inference_steps: int,
        height: int,
        width: int,
        guidance_scale: float,
        latents: Optional[jnp.ndarray] = None,
        neg_prompt_ids: Optional[jnp.ndarray] = None,
    ):
        """
        Run the denoising loop with classifier-free guidance and decode the result with the VAE.

        Args:
            prompt_ids: Token ids of the prompts; the first axis is the batch.
            params: Mapping that provides the `"text_encoder"`, `"unet"`, `"scheduler"` and `"vae"` entries.
            prng_seed: PRNG key used to sample the initial latents when `latents` is `None`.
            num_inference_steps: Number of denoising iterations.
            height: Output image height in pixels; must be divisible by 8.
            width: Output image width in pixels; must be divisible by 8.
            guidance_scale: Scalar, or array with one value per batch element, weighting the guidance term.
            latents: Optional initial latents of shape
                `(batch, unet.config.in_channels, height // vae_scale_factor, width // vae_scale_factor)`.
            neg_prompt_ids: Optional token ids for the unconditional branch; empty prompts are used when `None`.

        Returns:
            The decoded images, clipped to `[0, 1]` and transposed with `(0, 2, 3, 1)`.

        Raises:
            ValueError: If `height`/`width` are not divisible by 8 or `latents` has an unexpected shape.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        # get prompt text embeddings
        prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0]

        # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0`
        # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0`
        batch_size = prompt_ids.shape[0]

        max_length = prompt_ids.shape[-1]

        if neg_prompt_ids is None:
            uncond_input = self.tokenizer(
                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
            ).input_ids
        else:
            uncond_input = neg_prompt_ids
        negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
        context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])

        # Ensure model output will be `float32` before going into the scheduler.
        # The trailing singleton axes let a scalar or a per-sample vector broadcast against the
        # 4-D noise prediction (wrapping in a list produced `(1, batch)`, which only broadcast
        # for a batch of one).
        guidance_scale = jnp.asarray(guidance_scale, dtype=jnp.float32).reshape(-1, 1, 1, 1)

        latents_shape = (
            batch_size,
            self.unet.config.in_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if latents is None:
            latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
        else:
            if latents.shape != latents_shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        def loop_body(step, args):
            latents, scheduler_state = args
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            latents_input = jnp.concatenate([latents] * 2)

            t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
            timestep = jnp.broadcast_to(t, latents_input.shape[0])

            latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)

            # predict the noise residual
            noise_pred = self.unet.apply(
                {"params": params["unet"]},
                jnp.array(latents_input),
                jnp.array(timestep, dtype=jnp.int32),
                encoder_hidden_states=context,
            ).sample
            # perform guidance
            noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
            return latents, scheduler_state

        scheduler_state = self.scheduler.set_timesteps(
            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
        )

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * params["scheduler"].init_noise_sigma

        if DEBUG:
            # run with python for loop
            for i in range(num_inference_steps):
                latents, scheduler_state = loop_body(i, (latents, scheduler_state))
        else:
            latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))

        # scale and decode the image latents with vae
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample

        image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
        return image

    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt_ids: jnp.ndarray,
        params: Union[Dict, FrozenDict],
        prng_seed: jax.Array,
        num_inference_steps: int = 50,
        height: Optional[int] = None,
        width: Optional[int] = None,
        guidance_scale: Union[float, jnp.ndarray] = 7.5,
        latents: Optional[jnp.ndarray] = None,
        neg_prompt_ids: Optional[jnp.ndarray] = None,
        return_dict: bool = True,
        jit: bool = False,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt_ids (`jnp.ndarray`):
                Tokenized prompts, e.g. the output of `prepare_inputs`. An array with more than two dimensions is
                treated as sharded across devices.
            params (`Dict` or `FrozenDict`):
                Model parameters. The `"text_encoder"`, `"unet"`, `"scheduler"` and `"vae"` entries are read during
                generation, and `"safety_checker"` is read when a safety checker is registered.
            prng_seed (`jax.Array`):
                PRNG key used to sample the initial latents when `latents` is not provided.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            guidance_scale (`float` or `jnp.ndarray`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                prompt at the expense of lower image quality. A Python number is expanded to one value per entry
                of the first axis of `prompt_ids`. Classifier-free guidance is currently always applied.
            latents (`jnp.ndarray`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                array is generated by sampling using the supplied `prng_seed`.
            neg_prompt_ids (`jnp.ndarray`, *optional*):
                Token ids used for the unconditional branch of the guidance. Empty prompts are used when omitted.
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions.

                    <Tip warning={true}>

                    This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
                    future release.

                    </Tip>

            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
                a plain tuple.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] is
                returned, otherwise a `tuple` is returned where the first element is a NumPy array with the generated
                images and the second element holds the per-image "not-safe-for-work" (nsfw) flags, or `False`
                when no safety checker is registered.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if isinstance(guidance_scale, (int, float)):
            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
            # shape information, as they may be sharded (when `jit` is `True`), or not.
            # Integers are accepted too and stored as float32.
            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
            if len(prompt_ids.shape) > 2:
                # Assume sharded
                guidance_scale = guidance_scale[:, None]

        if jit:
            images = _p_generate(
                self,
                prompt_ids,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )
        else:
            images = self._generate(
                prompt_ids,
                params,
                prng_seed,
                num_inference_steps,
                height,
                width,
                guidance_scale,
                latents,
                neg_prompt_ids,
            )

        if self.safety_checker is not None:
            safety_params = params["safety_checker"]
            images_uint8_casted = np.asarray((images * 255).round().astype("uint8"))
            images = np.asarray(images).copy()

            # Flatten every leading axis so the same code serves sharded
            # `(devices, batch, H, W, C)` and unsharded `(batch, H, W, C)` outputs.
            output_shape = images.shape
            image_shape = output_shape[-3:]
            num_images = int(np.prod(output_shape[:-3]))

            images_uint8_casted = images_uint8_casted.reshape(num_images, *image_shape)
            images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)

            # block images
            flat_images = images.reshape(num_images, *image_shape)
            for i, is_nsfw in enumerate(has_nsfw_concept):
                if is_nsfw:
                    flat_images[i] = np.asarray(images_uint8_casted[i])

            images = flat_images.reshape(output_shape)
        else:
            images = np.asarray(images)
            has_nsfw_concept = False

        if not return_dict:
            return (images, has_nsfw_concept)

        return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
# `pipe`, `num_inference_steps`, `height` and `width` are static: changing any of them triggers recompilation.
# `guidance_scale` and the (sharded) array inputs are mapped over their leading axis (hence, `0`).
@partial(
    jax.pmap,
    static_broadcasted_argnums=(0, 4, 5, 6),
    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),
)
def _p_generate(
    pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, neg_prompt_ids
):
    """`pmap`-ed entry point that forwards all arguments, in order, to `pipe._generate`."""
    static_settings = (num_inference_steps, height, width)
    mapped_extras = (guidance_scale, latents, neg_prompt_ids)
    return pipe._generate(prompt_ids, params, prng_seed, *static_settings, *mapped_extras)
@partial(
    jax.pmap,
    static_broadcasted_argnums=(0,),
)
def _p_get_has_nsfw_concepts(pipe, features, params):
    """`pmap`-ed safety scoring: `pipe` is static, `features` and `params` are mapped over their leading axis."""
    nsfw_flags = pipe._get_has_nsfw_concepts(features, params)
    return nsfw_flags
def unshard(x: jnp.ndarray):
    """Merge the two leading axes of `x` into one, i.e. `'d b ... -> (d b) ...'` in einops notation."""
    device_axis, batch_axis, *trailing = x.shape
    merged_shape = (device_axis * batch_axis, *trailing)
    return x.reshape(merged_shape)