From af339debf46431694980476d67f97701a142c77d Mon Sep 17 00:00:00 2001
From: js1234567
Date: Wed, 24 Dec 2025 17:11:05 +0800
Subject: [PATCH] Add FSDP option for Flux2

---
 examples/dreambooth/README_flux2.md                    | 2 +-
 examples/dreambooth/train_dreambooth_lora_flux2.py     | 8 ++++----
 .../dreambooth/train_dreambooth_lora_flux2_img2img.py  | 8 ++++----
 src/diffusers/training_utils.py                        | 5 +----
 4 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md
index 876cdf2705..41a77c3bbc 100644
--- a/examples/dreambooth/README_flux2.md
+++ b/examples/dreambooth/README_flux2.md
@@ -100,7 +100,7 @@ This way, the text encoder model is not loaded into memory during training.
 > to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
 ### FSDP Text Encoder
 Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
-This way, the memory cost can be distributed in multiple nodes.
+This way, the memory cost can be distributed across multiple nodes.
 ### CPU Offloading
 To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
 ### Latent Caching
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py
index 71ef89a359..6bba0b94b1 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -44,6 +44,7 @@ import shutil
 import warnings
 from contextlib import nullcontext
 from pathlib import Path
+from typing import Any
 
 import numpy as np
 import torch
@@ -63,7 +64,6 @@ from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm.auto import tqdm
 from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
-from typing import Any
 
 import diffusers
 from diffusers import (
@@ -1292,7 +1292,7 @@ def main(args):
             raise ValueError("No transformer model found in 'models'")
 
         # 2) Optionally gather FSDP state dict once
-        state_dict = accelerator.get_state_dict(models) if is_fsdp else None
+        state_dict = accelerator.get_state_dict(model) if is_fsdp else None
 
         # 3) Only main process materializes the LoRA state dict
         transformer_lora_layers_to_save = None
@@ -1302,8 +1302,8 @@ def main(args):
                 peft_kwargs["state_dict"] = state_dict
 
             transformer_lora_layers_to_save = get_peft_model_state_dict(
-                unwrap_model(transformer_model) if is_fsdp else transformer_model
-                ** peft_kwargs,
+                unwrap_model(transformer_model) if is_fsdp else transformer_model,
+                **peft_kwargs,
             )
 
         if is_fsdp:
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
index 48d4000cf8..c22c48ecae 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
@@ -43,6 +43,7 @@ import random
 import shutil
 from contextlib import nullcontext
 from pathlib import Path
+from typing import Any
 
 import numpy as np
 import torch
@@ -61,7 +62,6 @@ from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm.auto import tqdm
 from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
-from typing import Any
 
 import diffusers
 from diffusers import (
@@ -75,7 +75,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
 from diffusers.training_utils import (
     _collate_lora_metadata,
-    _to_cpu_contiguous
+    _to_cpu_contiguous,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -1229,7 +1229,7 @@ def main(args):
             raise ValueError("No transformer model found in 'models'")
 
         # 2) Optionally gather FSDP state dict once
-        state_dict = accelerator.get_state_dict(models) if is_fsdp else None
+        state_dict = accelerator.get_state_dict(model) if is_fsdp else None
 
         # 3) Only main process materializes the LoRA state dict
         transformer_lora_layers_to_save = None
@@ -1239,7 +1239,7 @@ def main(args):
                 peft_kwargs["state_dict"] = state_dict
 
             transformer_lora_layers_to_save = get_peft_model_state_dict(
-                unwrap_model(transformer_model) if is_fsdp else transformer_model
+                unwrap_model(transformer_model) if is_fsdp else transformer_model,
                 **peft_kwargs,
             )
 
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 90523c4c3c..2d2f26b266 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -403,10 +403,7 @@ def find_nearest_bucket(h, w, bucket_options):
 
 
 def _to_cpu_contiguous(state_dicts) -> dict:
-    return {
-        k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
-        for k, v in state_dicts.items()
-    }
+    return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()}
 
 
 def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
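
Note for reviewers (not part of the patch): a minimal, self-contained sketch of what the `_to_cpu_contiguous` helper touched above does to a gathered state dict before the LoRA weights are extracted with `get_peft_model_state_dict(..., state_dict=...)` on the main process. The `gathered`/`cpu_ready` names and the toy tensors are purely illustrative assumptions; only `_to_cpu_contiguous`, `accelerator.get_state_dict`, and `get_peft_model_state_dict` come from the patch itself.

```python
import torch


def _to_cpu_contiguous(state_dicts) -> dict:
    # Detach tensors, move them to CPU, and make them contiguous;
    # non-tensor values (e.g. metadata) pass through unchanged.
    return {k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dicts.items()}


if __name__ == "__main__":
    # Illustrative stand-in for the dict returned by `accelerator.get_state_dict(model)`
    # under FSDP: a fully gathered state dict whose tensors may still be on GPU
    # and non-contiguous.
    gathered = {
        "lora_A.weight": torch.randn(8, 16).t(),  # transposed view -> non-contiguous
        "lora_B.weight": torch.randn(16, 8),
        "alpha": 8,  # non-tensor entry is returned as-is
    }
    cpu_ready = _to_cpu_contiguous(gathered)
    for name, value in cpu_ready.items():
        if isinstance(value, torch.Tensor):
            print(name, value.device, value.is_contiguous())
    # In the training scripts, the gathered dict is forwarded to PEFT via
    # get_peft_model_state_dict(..., state_dict=...) so that the LoRA weights
    # can be pulled out of the pre-gathered full state dict on the main process.
```

Making the tensors contiguous on CPU also matters because safetensors refuses to serialize non-contiguous tensors, which is presumably why the helper exists in `training_utils.py`.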