[[open-in-colab]]

# Trajectory Consistency Distillation-LoRA

Trajectory Consistency Distillation (TCD) enables a model to generate higher quality and more detailed images with fewer steps. Moreover, thanks to effective error mitigation during the distillation process, TCD performs well even at large numbers of inference steps.

The major advantages of TCD are:

- **Better than Teacher:** TCD demonstrates superior generative quality at both small and large inference steps and exceeds the performance of DPM-Solver++(2S) with Stable Diffusion XL (SDXL). No additional discriminator or LPIPS supervision is included during TCD training.
- **Flexible Inference Steps:** The number of inference steps for TCD sampling can be freely adjusted without adversely affecting image quality.
- **Freely change detail level:** During inference, the level of detail in the image can be adjusted with a single hyperparameter, *gamma* (exposed as the `eta` argument in the code examples below).

Tip

For more technical details of TCD, please refer to the paper or the official project page.

For large models like SDXL, TCD is trained with LoRA to reduce memory usage. This is also useful because you can reuse a LoRA across different finetuned models without further training, as long as they share the same base model.

This guide will show you how to perform inference with TCD-LoRAs for a variety of tasks like text-to-image and inpainting, as well as how you can easily combine TCD-LoRAs with other adapters. Choose one of the supported base models and its corresponding TCD-LoRA checkpoint from the table below to get started.

| Base model | TCD-LoRA checkpoint |
|---|---|
| stable-diffusion-v1-5 | TCD-SD15 |
| stable-diffusion-2-1-base | TCD-SD21-base |
| stable-diffusion-xl-base-1.0 | TCD-SDXL |

Make sure you have PEFT installed for better LoRA support.

```bash
pip install -U peft
```

## General tasks

In this guide, let's use the [StableDiffusionXLPipeline] and the [TCDScheduler]. Use the [~StableDiffusionPipeline.load_lora_weights] method to load the SDXL-compatible TCD-LoRA weights.

A few tips to keep in mind for TCD-LoRA inference are to:

- Keep `num_inference_steps` between 4 and 50.
- Set `eta` (used to control stochasticity at each step) between 0 and 1. Use a higher `eta` when increasing the number of inference steps, but the downside is that a larger `eta` in [TCDScheduler] leads to blurrier images. A value of 0.3 is recommended to produce good results.
### Text-to-image

```python
import torch
from diffusers import StableDiffusionXLPipeline, TCDScheduler

device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"

pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()

prompt = "Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna."

image = pipe(
    prompt=prompt,
    num_inference_steps=4,
    guidance_scale=0,
    eta=0.3,
    generator=torch.Generator(device=device).manual_seed(0),
).images[0]
```
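To see the effect of `eta` described in the tips above, here is a minimal sketch that renders the same seed at a few `eta` values, reusing `pipe`, `prompt`, and `device` from the snippet above; the candidate values are illustrative, not recommendations:

```python
from diffusers.utils import make_image_grid

# Higher eta adds stochasticity per step; values near 1.0 tend to blur the image.
images = [
    pipe(
        prompt=prompt,
        num_inference_steps=8,
        guidance_scale=0,
        eta=eta,
        generator=torch.Generator(device=device).manual_seed(0),
    ).images[0]
    for eta in (0.0, 0.3, 1.0)
]
grid = make_image_grid(images, rows=1, cols=3)
```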

### Inpainting

```python
import torch
from diffusers import AutoPipelineForInpainting, TCDScheduler
from diffusers.utils import load_image, make_image_grid

device = "cuda"
base_model_id = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"

pipe = AutoPipelineForInpainting.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url).resize((1024, 1024))
mask_image = load_image(mask_url).resize((1024, 1024))

prompt = "a tiger sitting on a park bench"

image = pipe(
    prompt=prompt,
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=8,
    guidance_scale=0,
    eta=0.3,
    strength=0.99,  # make sure to use `strength` below 1.0
    generator=torch.Generator(device=device).manual_seed(0),
).images[0]

grid_image = make_image_grid([init_image, mask_image, image], rows=1, cols=3)
```

## Community models

TCD-LoRA also works with many community finetuned models and plugins. For example, load the animagine-xl-3.0 checkpoint, which is a community finetuned version of SDXL for generating anime images.

```python
import torch
from diffusers import StableDiffusionXLPipeline, TCDScheduler

device = "cuda"
base_model_id = "cagliostrolab/animagine-xl-3.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"

pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()

prompt = "A man, clad in a meticulously tailored military uniform, stands with unwavering resolve. The uniform boasts intricate details, and his eyes gleam with determination. Strands of vibrant, windswept hair peek out from beneath the brim of his cap."

image = pipe(
    prompt=prompt,
    num_inference_steps=8,
    guidance_scale=0,
    eta=0.3,
    generator=torch.Generator(device=device).manual_seed(0),
).images[0]
```
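Because the TCD-LoRA weights only need to match the SDXL base architecture, trying another community finetune is essentially a one-line change. A minimal sketch; the checkpoint name below is a hypothetical placeholder for any SDXL-derived model on the Hub:

```python
# "your-org/your-sdxl-finetune" is a hypothetical checkpoint name.
base_model_id = "your-org/your-sdxl-finetune"

# Drop variant="fp16" if the checkpoint doesn't ship an fp16 variant.
pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_lora_weights(tcd_lora_id)  # the same TCD-SDXL LoRA is reused unchanged
pipe.fuse_lora()
```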

TCD-LoRA also supports other LoRAs trained on different styles. For example, let's load the TheLastBen/Papercut_SDXL LoRA and fuse it with the TCD-LoRA with the [~loaders.UNet2DConditionLoadersMixin.set_adapters] method.

Tip

Check out the Merge LoRAs guide to learn more about efficient merging methods.

```python
import torch
from diffusers import StableDiffusionXLPipeline, TCDScheduler

device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"
styled_lora_id = "TheLastBen/Papercut_SDXL"

pipe = StableDiffusionXLPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to(device)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_lora_weights(tcd_lora_id, adapter_name="tcd")
pipe.load_lora_weights(styled_lora_id, adapter_name="style")
pipe.set_adapters(["tcd", "style"], adapter_weights=[1.0, 1.0])

prompt = "papercut of a winter mountain, snow"

image = pipe(
    prompt=prompt,
    num_inference_steps=4,
    guidance_scale=0,
    eta=0.3,
    generator=torch.Generator(device=device).manual_seed(0),
).images[0]
```
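The `adapter_weights` control how strongly each LoRA contributes, so re-balancing them is a cheap way to trade the TCD behavior against style strength. A hedged follow-up; the 0.8 below is illustrative, not a recommended value:

```python
# Dial the style LoRA down relative to TCD (weights are illustrative).
pipe.set_adapters(["tcd", "style"], adapter_weights=[1.0, 0.8])
```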

## Adapters

TCD-LoRA is very versatile, and it can be combined with other adapter types like ControlNets, IP-Adapter, and AnimateDiff.

### Depth ControlNet

```python
import torch
import numpy as np
from PIL import Image
from transformers import DPTImageProcessor, DPTForDepthEstimation
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, TCDScheduler
from diffusers.utils import load_image, make_image_grid

device = "cuda"
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")

def get_depth_map(image):
    image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad(), torch.autocast(device):
        depth_map = depth_estimator(image).predicted_depth

    depth_map = torch.nn.functional.interpolate(
        depth_map.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )
    depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
    depth_map = (depth_map - depth_min) / (depth_max - depth_min)
    image = torch.cat([depth_map] * 3, dim=1)

    image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
    image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
    return image

base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
controlnet_id = "diffusers/controlnet-depth-sdxl-1.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"

controlnet = ControlNetModel.from_pretrained(
    controlnet_id,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_id,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload()

pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()

prompt = "stormtrooper lecture, photorealistic"

image = load_image("https://huggingface.co/lllyasviel/sd-controlnet-depth/resolve/main/images/stormtrooper.png")
depth_image = get_depth_map(image)

controlnet_conditioning_scale = 0.5  # recommended for good generalization

image = pipe(
    prompt,
    image=depth_image,
    num_inference_steps=4,
    guidance_scale=0,
    eta=0.3,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    generator=torch.Generator(device=device).manual_seed(0),
).images[0]

grid_image = make_image_grid([depth_image, image], rows=1, cols=2)
```

### Canny ControlNet

```python
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, TCDScheduler
from diffusers.utils import load_image, make_image_grid

device = "cuda"
base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
controlnet_id = "diffusers/controlnet-canny-sdxl-1.0"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"

controlnet = ControlNetModel.from_pretrained(
    controlnet_id,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_id,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload()

pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()

prompt = "ultrarealistic shot of a furry blue bird"

canny_image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png")

controlnet_conditioning_scale = 0.5  # recommended for good generalization

image = pipe(
    prompt,
    image=canny_image,
    num_inference_steps=4,
    guidance_scale=0,
    eta=0.3,
    controlnet_conditioning_scale=controlnet_conditioning_scale,
    generator=torch.Generator(device=device).manual_seed(0),
).images[0]

grid_image = make_image_grid([canny_image, image], rows=1, cols=2)
```

The inference parameters in these examples might not work for all cases, so we recommend trying different values for the `num_inference_steps`, `guidance_scale`, `controlnet_conditioning_scale`, and `cross_attention_kwargs` parameters and choosing the best ones, as in the sketch below.
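For example, a minimal sketch of such a sweep over `controlnet_conditioning_scale`, reusing `pipe` and `canny_image` from the Canny example above (the candidate values are illustrative):

```python
# Compare several conditioning scales side by side.
images = []
for scale in (0.3, 0.5, 0.7):
    images.append(
        pipe(
            prompt,
            image=canny_image,
            num_inference_steps=4,
            guidance_scale=0,
            eta=0.3,
            controlnet_conditioning_scale=scale,
            generator=torch.Generator(device=device).manual_seed(0),
        ).images[0]
    )
grid = make_image_grid(images, rows=1, cols=3)
```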

### IP-Adapter

This example shows how to use the TCD-LoRA with the IP-Adapter and SDXL. The `ip_adapter` module and the local `sdxl_models` paths come from the official IP-Adapter repository, so install it and download its SDXL checkpoints first.

```python
import torch
from diffusers import StableDiffusionXLPipeline, TCDScheduler
from diffusers.utils import load_image, make_image_grid
from ip_adapter import IPAdapterXL

device = "cuda"
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "sdxl_models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl.bin"
tcd_lora_id = "h1t/TCD-SDXL-LoRA"

pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

pipe.load_lora_weights(tcd_lora_id)
pipe.fuse_lora()

ip_model = IPAdapterXL(pipe, image_encoder_path, ip_ckpt, device)

ref_image = load_image("https://raw.githubusercontent.com/tencent-ailab/IP-Adapter/main/assets/images/woman.png").resize((512, 512))

prompt = "best quality, high quality, wearing sunglasses"

image = ip_model.generate(
    pil_image=ref_image,
    prompt=prompt,
    scale=0.5,
    num_samples=1,
    num_inference_steps=4,
    guidance_scale=0,
    eta=0.3,
    seed=0,
)[0]

grid_image = make_image_grid([ref_image, image], rows=1, cols=2)
```

### AnimateDiff

[AnimateDiff] allows animating images using Stable Diffusion models. TCD-LoRA can substantially accelerate the process without degrading image quality, yielding noticeably more lucid animations.

```python
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, TCDScheduler
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5")
pipe = AnimateDiffPipeline.from_pretrained(
    "frankjoshua/toonyou_beta6",
    motion_adapter=adapter,
).to("cuda")

# set TCDScheduler
pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)

# load TCD LoRA
pipe.load_lora_weights("h1t/TCD-SD15-LoRA", adapter_name="tcd")
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-in", weight_name="diffusion_pytorch_model.safetensors", adapter_name="motion-lora")

pipe.set_adapters(["tcd", "motion-lora"], adapter_weights=[1.0, 1.2])

prompt = "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
generator = torch.manual_seed(0)
frames = pipe(
    prompt=prompt,
    num_inference_steps=5,
    guidance_scale=0,
    cross_attention_kwargs={"scale": 1},
    num_frames=24,
    eta=0.3,
    generator=generator,
).frames[0]
export_to_gif(frames, "animation.gif")
```
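If you prefer a video file over a GIF, `diffusers.utils` also provides `export_to_video`. A small hedged follow-up reusing `frames` from above; the filename and frame rate are illustrative:

```python
from diffusers.utils import export_to_video

# Write the same frames as an MP4 instead of a GIF.
export_to_video(frames, "animation.mp4", fps=8)
```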
