ParaAttention

Large image and video generation models, such as FLUX.1-dev and HunyuanVideo, can be an inference challenge for real-time applications and deployment because of their size.

ParaAttention is a library that implements Context Parallelism and First Block Cache, which can be combined with other techniques such as torch.compile and fp8 dynamic quantization to accelerate inference.

This guide shows you how to apply ParaAttention to FLUX.1-dev and HunyuanVideo on NVIDIA L20 GPUs. No optimizations are applied for the baseline benchmark, except for the memory optimizations HunyuanVideo needs to avoid out-of-memory errors.

Our baseline benchmark shows that FLUX.1-dev is able to generate a 1024x1024 resolution image in 28 steps in 26.36 seconds, and HunyuanVideo is able to generate 129 frames at 720p resolution in 30 steps in 3675.71 seconds.

Tip

For even faster inference with context parallelism, try using NVIDIA A100 or H100 GPUs (if available) with NVLink support, especially when there is a large number of GPUs.

First Block Cache

Caching the outputs of the transformer blocks in the model and reusing them in later inference steps reduces the computation cost and makes inference faster.

However, it is hard to decide when to reuse the cache while preserving the quality of the generated images or videos. ParaAttention uses the residual difference of the first transformer block's output to approximate the difference between model outputs across steps. When this difference is small enough, the cached residual from a previous inference step is reused instead of running the remaining transformer blocks. In other words, most of the computation for that denoising step is skipped.

This achieves up to roughly a 2x speedup on FLUX.1-dev and HunyuanVideo inference with very good quality.
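The caching decision described above can be sketched without any framework. This is a simplified illustration, not ParaAttention's code: `ToyFirstBlockCache` and its toy blocks are hypothetical, vectors are plain Python lists, and the difference metric is assumed to be the mean absolute change of the first block's residual relative to the previous residual's mean magnitude. When that relative change is below `residual_diff_threshold`, the cached residual of the remaining blocks is added to the first block's output instead of running them.

```python
from typing import Callable, List, Optional

Vector = List[float]
Block = Callable[[Vector], Vector]


def relative_l1_diff(current: Vector, previous: Vector) -> float:
    """Mean absolute change, normalized by the mean magnitude of `previous` (assumed metric)."""
    num = sum(abs(c - p) for c, p in zip(current, previous)) / len(current)
    den = sum(abs(p) for p in previous) / len(previous)
    return num / den if den > 0 else float("inf")


class ToyFirstBlockCache:
    """Hypothetical toy model split into a first block and the remaining blocks."""

    def __init__(self, first_block: Block, remaining_blocks: Block,
                 residual_diff_threshold: float) -> None:
        self.first_block = first_block
        self.remaining_blocks = remaining_blocks
        self.threshold = residual_diff_threshold
        self.prev_first_residual: Optional[Vector] = None  # first-block residual at last full step
        self.cached_residual: Optional[Vector] = None      # remaining-blocks output minus their input
        self.skipped_steps = 0

    def __call__(self, x: Vector) -> Vector:
        first_out = self.first_block(x)
        first_residual = [o - i for o, i in zip(first_out, x)]
        if (self.prev_first_residual is not None and self.cached_residual is not None
                and relative_l1_diff(first_residual, self.prev_first_residual) < self.threshold):
            # First block barely changed: reuse the cached residual, skip the remaining blocks.
            self.skipped_steps += 1
            return [f + r for f, r in zip(first_out, self.cached_residual)]
        # Otherwise run the remaining blocks and refresh both caches.
        out = self.remaining_blocks(first_out)
        self.prev_first_residual = first_residual
        self.cached_residual = [o - f for o, f in zip(out, first_out)]
        return out
```

A larger threshold makes more steps fall below it, which is the speed and quality trade-off that `residual_diff_threshold` controls.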

Figure: Cache in Diffusion Transformer. How AdaCache works; First Block Cache is a variant of it.

To apply First Block Cache to FLUX.1-dev, call apply_cache_on_pipe as shown below. 0.08 is the default residual_diff_threshold value for FLUX models.

```python
import time
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(pipe, residual_diff_threshold=0.08)

# Enable memory savings
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()

begin = time.time()
image = pipe(
    "A cat holding a sign that says hello world",
    num_inference_steps=28,
).images[0]
end = time.time()
print(f"Time: {end - begin:.2f}s")

print("Saving image to flux.png")
image.save("flux.png")
```
| Optimizations | Original | FBCache rdt=0.06 | FBCache rdt=0.08 | FBCache rdt=0.10 | FBCache rdt=0.12 |
|---|---|---|---|---|---|
| Wall Time (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 |

First Block Cache reduced the inference time to 17.01 seconds, or 1.55x faster than the baseline, while maintaining nearly zero quality loss.

To apply First Block Cache to HunyuanVideo, call apply_cache_on_pipe as shown below. 0.06 is the default residual_diff_threshold value for HunyuanVideo models.

```python
import time
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "tencent/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    revision="refs/pr/18",
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch.float16,
    revision="refs/pr/18",
).to("cuda")

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

# 0.06 is the default threshold for HunyuanVideo
apply_cache_on_pipe(pipe, residual_diff_threshold=0.06)

pipe.vae.enable_tiling()

begin = time.time()
output = pipe(
    prompt="A cat walks on the grass, realistic",
    height=720,
    width=1280,
    num_frames=129,
    num_inference_steps=30,
).frames[0]
end = time.time()
print(f"Time: {end - begin:.2f}s")

print("Saving video to hunyuan_video.mp4")
export_to_video(output, "hunyuan_video.mp4", fps=15)
```

Video comparison: HunyuanVideo without FBCache vs. HunyuanVideo with FBCache.

First Block Cache reduced the inference time to 2271.06 seconds, or 1.62x faster than the baseline, while maintaining nearly zero quality loss.

fp8 quantization

fp8 with dynamic quantization further speeds up inference and reduces memory usage. Both the activations and weights must be quantized in order to use the 8-bit NVIDIA Tensor Cores.

Use float8_weight_only and float8_dynamic_activation_float8_weight to quantize the text encoder and transformer model.

The default quantization method is per-tensor quantization, but if your GPU supports row-wise quantization, you can also try it for better accuracy.
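Dynamic quantization computes scales at runtime from the tensor being quantized, and the granularity decides how many scales there are. The sketch below contrasts a single per-tensor scale with one scale per row, assuming the float8 e4m3 format (`torch.float8_e4m3fn`, whose largest finite value is 448). It only computes scales and how much of the fp8 range each row uses; it does not perform fp8 rounding and is not how torchao implements its kernels. The helper names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]

# Largest finite value of float8 e4m3 (torch.float8_e4m3fn).
FP8_E4M3_MAX = 448.0


def per_tensor_scale(x: Matrix) -> float:
    """One scale for the whole tensor, computed from its current absolute max."""
    amax = max(abs(v) for row in x for v in row)
    return amax / FP8_E4M3_MAX if amax > 0 else 1.0


def per_row_scales(x: Matrix) -> List[float]:
    """One scale per row, computed from each row's own absolute max."""
    scales = []
    for row in x:
        amax = max(abs(v) for v in row)
        scales.append(amax / FP8_E4M3_MAX if amax > 0 else 1.0)
    return scales


def range_used(x: Matrix, scales: List[float]) -> List[float]:
    """Fraction of the fp8 range each row occupies after dividing by its scale."""
    return [max(abs(v) for v in row) / s / FP8_E4M3_MAX for row, s in zip(x, scales)]
```

For `activations = [[0.5, -1.0, 0.25], [100.0, -50.0, 20.0]]`, a shared scale leaves the first row using only about 1% of the fp8 range, so its values lose precision, while per-row scales let both rows use the full range. This is the intuition behind row-wise quantization being more accurate.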

Install torchao with the command below.

```bash
pip3 install -U torch torchao
```

torch.compile with mode="max-autotune-no-cudagraphs" or mode="max-autotune" selects the best-performing kernels. Compilation can take a long time the first time the model is called, but it is worth it because later calls reuse the compiled model.

The examples below quantize the transformer with fp8 dynamic quantization and the text encoder with fp8 weight-only quantization. Quantizing the text encoder is optional and mainly reduces memory usage even more.

Tip

Dynamic quantization can significantly change the distribution of the model output, so you need to increase residual_diff_threshold for the cache to take effect.

```python
import time
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.12,  # Use a larger value to make the cache take effect
)

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only

quantize_(pipe.text_encoder, float8_weight_only())
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune-no-cudagraphs",
)

# Enable memory savings
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()

for i in range(2):
    begin = time.time()
    image = pipe(
        "A cat holding a sign that says hello world",
        num_inference_steps=28,
    ).images[0]
    end = time.time()
    if i == 0:
        print(f"Warm up time: {end - begin:.2f}s")
    else:
        print(f"Time: {end - begin:.2f}s")

print("Saving image to flux.png")
image.save("flux.png")
```

First Block Cache (residual_diff_threshold=0.12), fp8 dynamic quantization, and torch.compile together reduced the inference time to 7.56 seconds, or 3.48x faster than the baseline.
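The scripts in this section run the pipeline twice so that the first call, which triggers torch.compile, is reported separately as warm-up time. A minimal, framework-agnostic sketch of that pattern follows. `benchmark` is a hypothetical helper, not part of ParaAttention or diffusers, and the optional `sync` callback (for example `torch.cuda.synchronize`) is there because CUDA kernels run asynchronously.

```python
import time
from typing import Callable, List, Optional, Tuple


def benchmark(
    fn: Callable[[], object],
    warmup: int = 1,
    iters: int = 1,
    sync: Optional[Callable[[], None]] = None,
) -> Tuple[List[float], List[float]]:
    """Return (warm-up times, steady-state times) in seconds for calls to `fn`."""

    def timed_call() -> float:
        if sync is not None:
            sync()  # wait for previously queued GPU work before starting the clock
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # make sure the GPU has finished before stopping the clock
        return time.perf_counter() - start

    # The first call(s) include torch.compile's compilation and autotuning.
    warmup_times = [timed_call() for _ in range(warmup)]
    steady_times = [timed_call() for _ in range(iters)]
    return warmup_times, steady_times
```

With a compiled pipeline this might be called as `benchmark(lambda: pipe("A cat holding a sign that says hello world", num_inference_steps=28), sync=torch.cuda.synchronize)`, and only the steady-state times would be compared against the baseline.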

```python
import time
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "tencent/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    revision="refs/pr/18",
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch.float16,
    revision="refs/pr/18",
).to("cuda")

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(pipe)

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only

quantize_(pipe.text_encoder, float8_weight_only())
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune-no-cudagraphs",
)

# Enable memory savings
pipe.vae.enable_tiling()
# pipe.enable_model_cpu_offload()
# pipe.enable_sequential_cpu_offload()

for i in range(2):
    begin = time.time()
    output = pipe(
        prompt="A cat walks on the grass, realistic",
        height=720,
        width=1280,
        num_frames=129,
        num_inference_steps=1 if i == 0 else 30,
    ).frames[0]
    end = time.time()
    if i == 0:
        print(f"Warm up time: {end - begin:.2f}s")
    else:
        print(f"Time: {end - begin:.2f}s")

print("Saving video to hunyuan_video.mp4")
export_to_video(output, "hunyuan_video.mp4", fps=15)
```

An NVIDIA L20 GPU has only 48GB of memory. Because HunyuanVideo has very large activation tensors at high resolution and with a large number of frames, it can run out of memory (OOM) after compilation if enable_model_cpu_offload isn't called. On GPUs with less than 80GB of memory, try reducing the resolution and number of frames to avoid OOM errors.

Large video generation models are usually bottlenecked by attention computation rather than by the fully connected layers, so they don't benefit significantly from quantization and torch.compile.
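To see why reducing resolution and frame count helps so much, it can help to count the video tokens the transformer attends over. The sketch below assumes HunyuanVideo's VAE compresses 8x spatially and 4x temporally (keeping the first frame separately) and that the transformer uses 2x2 spatial patches with no temporal patching; check these against the model config before relying on the numbers. Text tokens are ignored, and both helpers are hypothetical.

```python
def hunyuan_video_tokens(
    height: int,
    width: int,
    num_frames: int,
    spatial_compression: int = 8,   # assumed VAE downsampling in height and width
    temporal_compression: int = 4,  # assumed VAE downsampling in time
    patch_size: int = 2,            # assumed spatial patch size of the transformer
    patch_size_t: int = 1,          # assumed temporal patch size of the transformer
) -> int:
    """Number of video tokens seen by the transformer's attention layers."""
    latent_frames = (num_frames - 1) // temporal_compression + 1  # first frame kept separately
    tokens_t = latent_frames // patch_size_t
    tokens_h = height // spatial_compression // patch_size
    tokens_w = width // spatial_compression // patch_size
    return tokens_t * tokens_h * tokens_w


def relative_attention_cost(tokens_a: int, tokens_b: int) -> float:
    """Self-attention FLOPs grow with the square of the sequence length."""
    return (tokens_a / tokens_b) ** 2
```

Under these assumptions, 720x1280 with 129 frames gives 118,800 video tokens, versus 32,640 at 544x960 with 61 frames, so attention alone is roughly 13x more expensive at the larger size. Fully connected layers grow only linearly with the token count, which is why attention dominates at high resolution and large frame counts.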

Context Parallelism

Context Parallelism parallelizes inference and scales with multiple GPUs. The ParaAttention compositional design allows you to combine Context Parallelism with First Block Cache and dynamic quantization.
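Context parallelism splits the token sequence across GPUs, so no single GPU holds all keys and values. The FLUX example in this section passes `max_ring_dim_size` to `init_context_parallel_mesh`, which suggests a ring-style dimension. The sketch below illustrates only the general ring-attention idea in plain Python under simplifying assumptions (one query, scalar values, and K/V shards visited in a loop instead of exchanged between GPUs); it is not ParaAttention's implementation. Each shard contributes a partial softmax, and the partial sums are merged by rescaling with the running maximum, which reproduces full attention exactly.

```python
import math
from typing import List, Sequence

Vec = Sequence[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def full_attention(q: Vec, keys: Sequence[Vec], values: Sequence[float]) -> float:
    """Softmax attention of one query over the whole (unsharded) sequence."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def shard(seq: Sequence, world_size: int) -> List[Sequence]:
    """Split a sequence into contiguous chunks, one per rank."""
    n = len(seq)
    bounds = [n * r // world_size for r in range(world_size + 1)]
    return [seq[bounds[r]:bounds[r + 1]] for r in range(world_size)]


def ring_attention(q: Vec, key_shards: Sequence[Sequence[Vec]],
                   value_shards: Sequence[Sequence[float]]) -> float:
    """Visit one K/V shard at a time (as if passed around a ring of GPUs)
    and merge the partial results with an online softmax."""
    running_max, denom, numer = -math.inf, 0.0, 0.0
    for keys, values in zip(key_shards, value_shards):
        scores = [dot(q, k) for k in keys]
        if not scores:
            continue  # a rank may hold an empty shard
        new_max = max(running_max, max(scores))
        correction = math.exp(running_max - new_max)  # rescales earlier partial sums
        denom = denom * correction + sum(math.exp(s - new_max) for s in scores)
        numer = numer * correction + sum(
            math.exp(s - new_max) * v for s, v in zip(scores, values))
        running_max = new_max
    return numer / denom
```

Because the merge is exact, splitting the sequence across more GPUs changes only where the work happens, not the result.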

Tip

Refer to the ParaAttention repository for detailed instructions and examples of how to scale inference with multiple GPUs.

If the inference process needs to be persistent and serviceable, we suggest using torch.multiprocessing to write your own inference processor. This eliminates the overhead of launching the process and of loading and recompiling the model for every request.
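A minimal sketch of such a persistent inference processor, assuming a request/response queue design: the model is loaded, optimized, and compiled once in a long-lived worker, and every request reuses it. `worker_loop`, `start_worker`, `STOP`, and the `load_model` callable are hypothetical names; `load_model` would build the pipeline and apply the optimizations from this guide. The sketch uses the standard `multiprocessing` module because `torch.multiprocessing` is a drop-in wrapper with the same API, and it uses the "spawn" start method because CUDA cannot be used in processes started with "fork".

```python
import multiprocessing as mp  # torch.multiprocessing exposes the same API
from typing import Any, Callable

STOP = None  # sentinel request that tells the worker to shut down


def worker_loop(load_model: Callable[[], Callable[[Any], Any]], requests, responses) -> None:
    """Load (and warm up) the model once, then serve requests until STOP arrives."""
    model = load_model()  # expensive: load weights, apply caching/quantization, compile
    while True:
        request = requests.get()
        if request is STOP:
            break
        request_id, prompt = request
        responses.put((request_id, model(prompt)))


def start_worker(load_model: Callable[[], Callable[[Any], Any]]):
    """Start a long-lived worker process. `load_model` must be a picklable,
    module-level function because the "spawn" start method pickles it."""
    ctx = mp.get_context("spawn")  # CUDA does not support the "fork" start method
    requests, responses = ctx.Queue(), ctx.Queue()
    proc = ctx.Process(target=worker_loop, args=(load_model, requests, responses))
    proc.start()
    return proc, requests, responses
```

A client would call `proc, requests, responses = start_worker(load_model)`, then `requests.put((0, prompt))` and `responses.get()` for each request, and finally `requests.put(STOP)` and `proc.join()`.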

The code sample below combines First Block Cache, fp8 dynamic quantization, torch.compile, and Context Parallelism for the fastest inference speed.

```python
import time
import torch
import torch.distributed as dist
from diffusers import FluxPipeline

dist.init_process_group()

torch.cuda.set_device(dist.get_rank())

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
from para_attn.parallel_vae.diffusers_adapters import parallelize_vae

mesh = init_context_parallel_mesh(
    pipe.device.type,
    max_ring_dim_size=2,
)
parallelize_pipe(
    pipe,
    mesh=mesh,
)
parallelize_vae(pipe.vae, mesh=mesh._flatten())

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.12,  # Use a larger value to make the cache take effect
)

from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only

quantize_(pipe.text_encoder, float8_weight_only())
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(
    pipe.transformer, mode="max-autotune-no-cudagraphs",
)

# Enable memory savings
# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())

for i in range(2):
    begin = time.time()
    image = pipe(
        "A cat holding a sign that says hello world",
        num_inference_steps=28,
        output_type="pil" if dist.get_rank() == 0 else "pt",
    ).images[0]
    end = time.time()
    if dist.get_rank() == 0:
        if i == 0:
            print(f"Warm up time: {end - begin:.2f}s")
        else:
            print(f"Time: {end - begin:.2f}s")

if dist.get_rank() == 0:
    print("Saving image to flux.png")
    image.save("flux.png")

dist.destroy_process_group()
```

Save to run_flux.py and launch it with torchrun.

```bash
# Use --nproc_per_node to specify the number of GPUs
torchrun --nproc_per_node=2 run_flux.py
```

Inference time is reduced to 8.20 seconds, or 3.21x faster than the baseline, with 2 NVIDIA L20 GPUs. On 4 L20s, inference time is 3.90 seconds, or 6.75x faster.

The code sample below combines First Block Cache and Context Parallelism for the fastest inference speed.

```python
import time
import torch
import torch.distributed as dist
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

dist.init_process_group()

torch.cuda.set_device(dist.get_rank())

model_id = "tencent/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    revision="refs/pr/18",
)
pipe = HunyuanVideoPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch.float16,
    revision="refs/pr/18",
).to("cuda")

from para_attn.context_parallel import init_context_parallel_mesh
from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
from para_attn.parallel_vae.diffusers_adapters import parallelize_vae

mesh = init_context_parallel_mesh(
    pipe.device.type,
)
parallelize_pipe(
    pipe,
    mesh=mesh,
)
parallelize_vae(pipe.vae, mesh=mesh._flatten())

from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

apply_cache_on_pipe(pipe)

# from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, float8_weight_only
#
# torch._inductor.config.reorder_for_compute_comm_overlap = True
#
# quantize_(pipe.text_encoder, float8_weight_only())
# quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
# pipe.transformer = torch.compile(
#     pipe.transformer, mode="max-autotune-no-cudagraphs",
# )

# Enable memory savings
pipe.vae.enable_tiling()
# pipe.enable_model_cpu_offload(gpu_id=dist.get_rank())
# pipe.enable_sequential_cpu_offload(gpu_id=dist.get_rank())

for i in range(2):
    begin = time.time()
    output = pipe(
        prompt="A cat walks on the grass, realistic",
        height=720,
        width=1280,
        num_frames=129,
        num_inference_steps=1 if i == 0 else 30,
        output_type="pil" if dist.get_rank() == 0 else "pt",
    ).frames[0]
    end = time.time()
    if dist.get_rank() == 0:
        if i == 0:
            print(f"Warm up time: {end - begin:.2f}s")
        else:
            print(f"Time: {end - begin:.2f}s")

if dist.get_rank() == 0:
    print("Saving video to hunyuan_video.mp4")
    export_to_video(output, "hunyuan_video.mp4", fps=15)

dist.destroy_process_group()
```

Save to run_hunyuan_video.py and launch it with torchrun.

```bash
# Use --nproc_per_node to specify the number of GPUs
torchrun --nproc_per_node=8 run_hunyuan_video.py
```

Inference time is reduced to 649.23 seconds, or 5.66x faster than the baseline, with 8 NVIDIA L20 GPUs.
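Both context-parallel scripts call `torch.cuda.set_device(dist.get_rank())`, which picks the right GPU only when every process runs on one node, as in the single-node `torchrun` commands above. torchrun also exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` to every process it launches. The small sketch below reads them, assuming you might later launch across several nodes, where the GPU index should come from `LOCAL_RANK` instead; `launch_info` is a hypothetical helper.

```python
import os
from typing import Dict, Mapping


def launch_info(env: Mapping[str, str] = os.environ) -> Dict[str, int]:
    """Read the rank variables that torchrun sets for each process it launches."""
    return {
        "rank": int(env["RANK"]),              # global rank across all nodes
        "local_rank": int(env["LOCAL_RANK"]),  # rank on this node, typically the GPU index
        "world_size": int(env["WORLD_SIZE"]),  # total number of processes
    }
```

In a multi-node run, `torch.cuda.set_device(launch_info()["local_rank"])` would replace the call based on `dist.get_rank()`.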

Benchmarks

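FBCache = First Block Cache, rdt = residual_diff_threshold, FP8 DQ = fp8 dynamic quantization, CP = Context Parallelism.

FLUX.1-dev

| GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup |
|---|---|---|---|---|
| NVIDIA L20 | 1 | Baseline | 26.36 | 1.00x |
| NVIDIA L20 | 1 | FBCache (rdt=0.08) | 17.01 | 1.55x |
| NVIDIA L20 | 1 | FP8 DQ | 13.40 | 1.96x |
| NVIDIA L20 | 1 | FBCache (rdt=0.12) + FP8 DQ | 7.56 | 3.48x |
| NVIDIA L20 | 2 | FBCache (rdt=0.12) + FP8 DQ + CP | 4.92 | 5.35x |
| NVIDIA L20 | 4 | FBCache (rdt=0.12) + FP8 DQ + CP | 3.90 | 6.75x |

HunyuanVideo

| GPU Type | Number of GPUs | Optimizations | Wall Time (s) | Speedup |
|---|---|---|---|---|
| NVIDIA L20 | 1 | Baseline | 3675.71 | 1.00x |
| NVIDIA L20 | 1 | FBCache | 2271.06 | 1.62x |
| NVIDIA L20 | 2 | FBCache + CP | 1132.90 | 3.24x |
| NVIDIA L20 | 4 | FBCache + CP | 718.15 | 5.12x |
| NVIDIA L20 | 8 | FBCache + CP | 649.23 | 5.66x |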