mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[docs] improve distributed inference cp docs. (#12810)
* improve distributed inference cp docs.

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
@@ -237,6 +237,8 @@ By selectively loading and unloading the models you need at a given stage and sh
Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.
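
As a minimal sketch, assuming `pipeline` has already been loaded (the backend string must be one of those listed in the table linked above):

```py
# "_native_cudnn" ships with PyTorch; backends like "flash" typically
# require extra packages (e.g. flash-attn) to be installed.
pipeline.transformer.set_attention_backend("_native_cudnn")
```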
### Ring Attention
Key (K) and value (V) shards are passed between devices using [Ring Attention](https://huggingface.co/papers/2310.01889), so each query split eventually sees every other token's K/V. Each GPU computes attention between its local queries and the K/V block it currently holds, then forwards that block to the next GPU in the ring. No single GPU ever holds the full sequence, which keeps per-device memory bounded, and the block transfers overlap with computation to hide communication latency.
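
A minimal, single-process sketch of this schedule (illustrative only, not the diffusers implementation): each simulated rank keeps its query shard while the K/V shards rotate around the ring, and partial results are folded in with a numerically stable streaming softmax.

```py
import torch


def ring_attention_sim(q, k, v, world_size):
    """Single-process simulation of the Ring Attention schedule.

    q, k, v: (seq_len, dim) tensors; the sequence is split into
    `world_size` contiguous shards, one per simulated rank.
    """
    dim = q.shape[-1]
    q_shards = q.chunk(world_size)
    kv_shards = list(zip(k.chunk(world_size), v.chunk(world_size)))

    outputs = []
    for rank in range(world_size):
        qi = q_shards[rank]
        # running (max, normalizer, weighted sum) for a stable streaming softmax
        m = torch.full((qi.shape[0], 1), float("-inf"))
        l = torch.zeros(qi.shape[0], 1)
        o = torch.zeros_like(qi)
        for step in range(world_size):
            # the K/V block this rank holds at this step of the ring rotation
            kj, vj = kv_shards[(rank + step) % world_size]
            s = qi @ kj.T / dim**0.5
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            p = torch.exp(s - m_new)
            scale = torch.exp(m - m_new)
            l = l * scale + p.sum(dim=-1, keepdim=True)
            o = o * scale + p @ vj
            m = m_new
        outputs.append(o / l)  # finish the softmax normalization
    return torch.cat(outputs)


q, k, v = (torch.randn(8, 16) for _ in range(3))
expected = torch.softmax(q @ k.T / 16**0.5, dim=-1) @ v
assert torch.allclose(ring_attention_sim(q, k, v, world_size=4), expected, atol=1e-5)
```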
@@ -245,40 +247,60 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf

```py
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device


def main():
    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
    )
    pipeline.transformer.set_attention_backend("_native_cudnn")

    cp_config = ContextParallelConfig(ring_degree=world_size)
    pipeline.transformer.enable_parallelism(config=cp_config)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    # Must specify a seeded generator so all ranks start from the same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        prompt,
        guidance_scale=3.5,
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if dist.get_rank() == 0:
        image.save("output.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Run the script above with a PyTorch-compatible distributed launcher such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html). Set `--nproc-per-node` to the number of available GPUs.

```shell
torchrun --nproc-per-node 2 above_script.py
```
### Ulysses Attention
[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends and receives data from every other device). After this exchange, each GPU holds all tokens for only a subset of attention heads. Each GPU computes attention locally over the full sequence for its assigned heads, then performs a second all-to-all to regroup the results by tokens for the next layer, as sketched below.
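
A minimal, single-process sketch of the two all-to-alls (illustrative only, not the diffusers implementation): the first exchange leaves each simulated rank with the full sequence for a subset of heads, and the second restores the original token sharding.

```py
import torch


def ulysses_attention_sim(q, k, v, world_size):
    """q, k, v: (num_heads, seq_len, dim); tokens start sharded across ranks."""
    num_heads, seq_len, dim = q.shape
    heads_per_rank = num_heads // world_size

    # each rank initially holds all heads for its contiguous token shard
    q_sh, k_sh, v_sh = (x.chunk(world_size, dim=1) for x in (q, k, v))

    per_rank_outputs = []
    for rank in range(world_size):
        # first all-to-all: collect *all tokens* for this rank's subset of heads
        h0, h1 = rank * heads_per_rank, (rank + 1) * heads_per_rank
        qh = torch.cat([shard[h0:h1] for shard in q_sh], dim=1)
        kh = torch.cat([shard[h0:h1] for shard in k_sh], dim=1)
        vh = torch.cat([shard[h0:h1] for shard in v_sh], dim=1)
        # full attention computed locally over the whole sequence for these heads
        attn = torch.softmax(qh @ kh.transpose(-1, -2) / dim**0.5, dim=-1) @ vh
        per_rank_outputs.append(attn)
    # second all-to-all: regroup by tokens so each rank again holds all heads
    # for its original token shard (concatenation plays that role here)
    return torch.cat(per_rank_outputs, dim=0)


q, k, v = (torch.randn(8, 16, 32) for _ in range(3))
expected = torch.softmax(q @ k.transpose(-1, -2) / 32**0.5, dim=-1) @ v
assert torch.allclose(ulysses_attention_sim(q, k, v, world_size=4), expected, atol=1e-5)
```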
@@ -288,5 +310,26 @@ finally:
Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
```py
# Set ulysses_degree to match the number of GPUs participating in the process group.
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```
### parallel_config
Pass `parallel_config` during model initialization to enable context parallelism.
```py
import torch
from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig

# assumes the process group and per-rank `device` are set up as in the Ring Attention example above
CKPT_ID = "black-forest-labs/FLUX.1-dev"

cp_config = ContextParallelConfig(ring_degree=2)
transformer = AutoModel.from_pretrained(
    CKPT_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=cp_config,
)

pipeline = DiffusionPipeline.from_pretrained(
    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```
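
Depending on your diffusers version, `ring_degree` and `ulysses_degree` may also be combinable in a single [`ContextParallelConfig`], with their product matching the number of participating GPUs. The following is an assumption to verify against your installed version, not confirmed behavior:

```py
# Hypothetical 4-GPU layout: 2-way Ulysses groups arranged in a 2-way ring.
# Check that your diffusers version supports composing both degrees before relying on this.
cp_config = ContextParallelConfig(ring_degree=2, ulysses_degree=2)
```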