# Distributed inference

On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel.
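
A common pattern is a small script in which each process loads the pipeline on its own GPU and generates an image for a different prompt, as in the sketch below. This is a minimal, hypothetical illustration with PyTorch Distributed; the file name `run_distributed.py`, the checkpoint, and the prompts are placeholder choices.

```py
# Minimal sketch (assumed saved as run_distributed.py): each process selects a GPU from
# its rank and generates an image for a different prompt.
import torch
import torch.distributed as dist
from diffusers import DiffusionPipeline

def main():
    dist.init_process_group("nccl")  # torchrun sets the rank/world-size environment variables
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

    pipeline = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
    ).to(device)

    prompts = ["a dog", "a cat"]  # one prompt per process in this two-GPU example
    image = pipeline(prompts[rank]).images[0]
    image.save(f"result_{rank}.png")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```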
Launch the script with `torchrun` and use `--nproc_per_node` to set the number of GPUs, for example `torchrun --nproc_per_node=2 run_distributed.py`.

> [!TIP]
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.

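
For example, a minimal sketch of pipeline-level placement might look like the following; the checkpoint is a placeholder, and the `balanced` strategy spreads the pipeline's models over the available GPUs.

```py
# Minimal sketch: let Accelerate place the pipeline's model-level components on the
# available GPUs instead of sharding them manually.
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    device_map="balanced",  # distribute the pipeline's models across GPUs
    torch_dtype=torch.float16,
)
image = pipeline("a photo of an astronaut riding a horse").images[0]
```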
## Model sharding

Modern diffusion systems such as [Flux](../api/pipelines/flux) are very large and have multiple models. For example, [Flux.1-Dev](https://hf.co/black-forest-labs/FLUX.1-dev) is made up of two text encoders - [T5-XXL](https://hf.co/google/t5-v1_1-xxl) and [CLIP-L](https://hf.co/openai/clip-vit-large-patch14) - a [diffusion transformer](../api/models/flux_transformer), and a [VAE](../api/models/autoencoderkl). With a model this size, it can be challenging to run inference on consumer GPUs.

Model sharding is a technique that distributes models across GPUs when the models don't fit on a single GPU. The example below assumes two 16GB GPUs are available for inference.

Start by computing the text embeddings with the text encoders. Keep the text encoders on two GPUs by setting `device_map="balanced"`. The `balanced` strategy evenly distributes the model on all available GPUs. Use the `max_memory` parameter to allocate the maximum amount of memory for each text encoder on each GPU.

> [!TIP]
> **Only** load the text encoders for this step! The diffusion transformer and VAE are loaded in a later step to preserve memory.

```py
from diffusers import FluxPipeline
import torch

prompt = "a photo of a dog with cat-like look"

# Load only the text encoders and tokenizers; the transformer and VAE come later
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=None,
    vae=None,
    device_map="balanced",
    max_memory={0: "16GB", 1: "16GB"},
    torch_dtype=torch.bfloat16
)

with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=512
    )
```

Once the text embeddings are computed, remove the text encoders from the GPU to make space for the diffusion transformer.

```py
import gc

def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

# Delete the text encoders, tokenizers, and the pipeline wrapper to free GPU memory
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline

flush()
```

Load the diffusion transformer next, which has 12.5B parameters. This time, set `device_map="auto"` to automatically distribute the model across two 16GB GPUs. The `auto` strategy is backed by [Accelerate](https://hf.co/docs/accelerate/index) and available as a part of the [Big Model Inference](https://hf.co/docs/accelerate/concept_guides/big_model_inference) feature. It starts by distributing a model across the fastest device first (the GPU) before moving to slower devices like the CPU and hard drive if needed. The trade-off of storing model parameters on slower devices is higher inference latency.

```py
from diffusers import FluxTransformer2DModel
import torch

# device_map="auto" lets Accelerate shard the transformer across the available GPUs
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16
)
```

> [!TIP]
> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models.

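
For illustration, `hf_device_map` is a plain dictionary mapping component names to devices. The snippet below is a hypothetical sketch; the exact keys and device assignments depend on how the models were placed.

```py
# Hypothetical illustration: a pipeline or model loaded with a device_map records its
# placement in `hf_device_map`. Exact keys and device indices depend on your setup.
print(pipeline.hf_device_map)
# e.g. {"text_encoder": 0, "text_encoder_2": 1}
```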
Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because they aren't needed for this step.

```py
# Reuse the sharded transformer; skip the text encoders and VAE for this step
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    vae=None,
    transformer=transformer,
    torch_dtype=torch.bfloat16
)

print("Running denoising.")
height, width = 768, 1360
latents = pipeline(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=50,
    guidance_scale=3.5,
    height=height,
    width=width,
    output_type="latent",
).images
```

Remove the pipeline and transformer from memory as they're no longer needed.

```py
del pipeline.transformer
del pipeline

flush()
```

Finally, decode the latents with the VAE into an image. The VAE is typically small enough to be loaded on a single GPU.

```py
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
import torch

vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
).to("cuda")
vae_scale_factor = 2 ** (len(vae.config.block_out_channels))
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)

with torch.no_grad():
    print("Running decoding.")
    latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

    image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type="pil")
    image[0].save("split_transformer.png")
```

By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.