1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[docs] improve distributed inference cp docs. (#12810)

* improve distributed inference cp docs.

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
Sayak Paul
2025-12-11 00:25:07 +08:00
committed by GitHub
parent be3c2a0667
commit 6708f5c76d

View File

@@ -237,6 +237,8 @@ By selectively loading and unloading the models you need at a given stage and sh
Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.
### Ring Attention
Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.
@@ -245,40 +247,60 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf
```py
import torch
from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig
try:
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
def setup_distributed():
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
pipeline.transformer.set_attention_backend("flash")
return device
def main():
device = setup_distributed()
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
)
pipeline.transformer.set_attention_backend("_native_cudnn")
cp_config = ContextParallelConfig(ring_degree=world_size)
pipeline.transformer.enable_parallelism(config=cp_config)
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
# Must specify generator so all ranks start with same latents (or pass your own)
generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]
if rank == 0:
image.save("output.png")
image = pipeline(
prompt,
guidance_scale=3.5,
num_inference_steps=50,
generator=generator,
).images[0]
except Exception as e:
print(f"An error occurred: {e}")
torch.distributed.breakpoint()
raise
if dist.get_rank() == 0:
image.save(f"output.png")
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
if dist.is_initialized():
dist.destroy_process_group()
if __name__ == "__main__":
main()
```
The script above needs to be run with a distributed launcher, such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html), that is compatible with PyTorch. `--nproc-per-node` is set to the number of GPUs available.
/```shell
`torchrun --nproc-per-node 2 above_script.py`.
/```
### Ulysses Attention
[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its head, then performs another all-to-all to regroup results by tokens for the next layer.
@@ -288,5 +310,26 @@ finally:
Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
```py
# Depending on the number of GPUs available.
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```
### parallel_config
Pass `parallel_config` during model initialization to enable context parallelism.
```py
CKPT_ID = "black-forest-labs/FLUX.1-dev"
cp_config = ContextParallelConfig(ring_degree=2)
transformer = AutoModel.from_pretrained(
CKPT_ID,
subfolder="transformer",
torch_dtype=torch.bfloat16,
parallel_config=cp_config
)
pipeline = DiffusionPipeline.from_pretrained(
CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```