diff --git a/examples/dreambooth/README_flux2.md b/examples/dreambooth/README_flux2.md
index 1d17778113..876cdf2705 100644
--- a/examples/dreambooth/README_flux2.md
+++ b/examples/dreambooth/README_flux2.md
@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
 This way, the text encoder model is not loaded into memory during training.
 > [!NOTE]
 > to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
+### FSDP Text Encoder
+Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, you can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
+This way, the memory cost is distributed across multiple GPUs/nodes.
 ### CPU Offloading
 To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
 ### Latent Caching
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py
index ff502d9309..71ef89a359 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -47,7 +47,6 @@ from pathlib import Path
 
 import numpy as np
 import torch
-import torch.distributed as dist
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -64,6 +63,7 @@ from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm.auto import tqdm
 from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
+from typing import Any
 
 import diffusers
 from diffusers import (
@@ -76,6 +76,7 @@ from diffusers import (
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
     _collate_lora_metadata,
+    _to_cpu_contiguous,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -96,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
+if getattr(torch, "distributed", None) is not None:
+    import torch.distributed as dist
+
 if is_wandb_available():
     import wandb
 
@@ -1271,43 +1275,44 @@ def main(args):
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
+        transformer_cls = type(unwrap_model(transformer))
+
+        # 1) Validate and pick the transformer model
+        modules_to_save: dict[str, Any] = {}
+        transformer_model = None
+
+        for model in models:
+            if isinstance(unwrap_model(model), transformer_cls):
+                transformer_model = model
+                modules_to_save["transformer"] = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        if transformer_model is None:
+            raise ValueError("No transformer model found in 'models'")
+
+        # 2) Optionally gather FSDP state dict once
+        state_dict = accelerator.get_state_dict(models) if is_fsdp else None
+
+        # 3) Only main process materializes the LoRA state dict
         transformer_lora_layers_to_save = None
-        modules_to_save = {}
-
-        if is_fsdp:
-            for model in models:
-                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
-                    state_dict = accelerator.get_state_dict(models)
-
-                    if accelerator.is_main_process:
-                        transformer_lora_layers_to_save = get_peft_model_state_dict(
-                            unwrap_model(model),
-                            state_dict=state_dict,
-                        )
-                        transformer_lora_layers_to_save = {
-                            k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
-                            for k, v in transformer_lora_layers_to_save.items()
-                        }
-                    modules_to_save["transformer"] = model
-
-                    # make sure to pop weight so that corresponding model is not saved again
-                    if weights:
-                        weights.pop()
-        else:
-            if accelerator.is_main_process:
-                transformer_lora_layers_to_save = None
-                modules_to_save = {}
-                for model in models:
-                    if isinstance(model, type(unwrap_model(transformer))):
-                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                        modules_to_save["transformer"] = model
-                    else:
-                        raise ValueError(f"unexpected save model: {model.__class__}")
-
-                    # make sure to pop weight so that corresponding model is not saved again
-                    weights.pop()
-
         if accelerator.is_main_process:
+            peft_kwargs = {}
+            if is_fsdp:
+                peft_kwargs["state_dict"] = state_dict
+
+            transformer_lora_layers_to_save = get_peft_model_state_dict(
+                unwrap_model(transformer_model) if is_fsdp else transformer_model,
+                **peft_kwargs,
+            )
+
+            if is_fsdp:
+                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
+
+            # make sure to pop weight so that corresponding model is not saved again
+            if weights:
+                weights.pop()
+
             Flux2Pipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
index cd2d493b18..48d4000cf8 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
@@ -46,7 +46,6 @@ from pathlib import Path
 
 import numpy as np
 import torch
-import torch.distributed as dist
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -62,6 +61,7 @@ from torchvision import transforms
 from torchvision.transforms import functional as TF
 from tqdm.auto import tqdm
 from transformers import Mistral3ForConditionalGeneration, PixtralProcessor
+from typing import Any
 
 import diffusers
 from diffusers import (
@@ -75,6 +75,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
 from diffusers.training_utils import (
     _collate_lora_metadata,
+    _to_cpu_contiguous,
     cast_training_params,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
@@ -96,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
+if getattr(torch, "distributed", None) is not None:
+    import torch.distributed as dist
+
 if is_wandb_available():
     import wandb
 
@@ -1208,42 +1212,44 @@ def main(args):
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
+        transformer_cls = type(unwrap_model(transformer))
+
+        # 1) Validate and pick the transformer model
+        modules_to_save: dict[str, Any] = {}
+        transformer_model = None
+
+        for model in models:
+            if isinstance(unwrap_model(model), transformer_cls):
+                transformer_model = model
+                modules_to_save["transformer"] = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        if transformer_model is None:
+            raise ValueError("No transformer model found in 'models'")
+
+        # 2) Optionally gather FSDP state dict once
+        state_dict = accelerator.get_state_dict(models) if is_fsdp else None
+
+        # 3) Only main process materializes the LoRA state dict
         transformer_lora_layers_to_save = None
-        modules_to_save = {}
-        if is_fsdp:
-            for model in models:
-                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
-                    state_dict = accelerator.get_state_dict(models)
-
-                    if accelerator.is_main_process:
-                        transformer_lora_layers_to_save = get_peft_model_state_dict(
-                            unwrap_model(model),
-                            state_dict=state_dict,
-                        )
-                        transformer_lora_layers_to_save = {
-                            k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
-                            for k, v in transformer_lora_layers_to_save.items()
-                        }
-                    modules_to_save["transformer"] = model
-
-                    # make sure to pop weight so that corresponding model is not saved again
-                    if weights:
-                        weights.pop()
-        else:
-            if accelerator.is_main_process:
-                transformer_lora_layers_to_save = None
-                modules_to_save = {}
-                for model in models:
-                    if isinstance(model, type(unwrap_model(transformer))):
-                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                        modules_to_save["transformer"] = model
-                    else:
-                        raise ValueError(f"unexpected save model: {model.__class__}")
-
-                    # make sure to pop weight so that corresponding model is not saved again
-                    weights.pop()
-
         if accelerator.is_main_process:
+            peft_kwargs = {}
+            if is_fsdp:
+                peft_kwargs["state_dict"] = state_dict
+
+            transformer_lora_layers_to_save = get_peft_model_state_dict(
+                unwrap_model(transformer_model) if is_fsdp else transformer_model,
+                **peft_kwargs,
+            )
+
+            if is_fsdp:
+                transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
+
+            # make sure to pop weight so that corresponding model is not saved again
+            if weights:
+                weights.pop()
+
             Flux2Pipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 9b09d2c814..90523c4c3c 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -402,6 +402,16 @@ def find_nearest_bucket(h, w, bucket_options):
     return best_bucket_idx
 
 
+def _to_cpu_contiguous(state_dict) -> dict:
+    """
+    Detach every tensor in a state dict and move it to contiguous CPU memory; non-tensor values are passed through.
+    """
+    return {
+        k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+        for k, v in state_dict.items()
+    }
+
+
 def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
     """
     Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs.
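Below is a minimal usage sketch, outside the patch itself, of what the new `_to_cpu_contiguous` helper does to a gathered LoRA state dict before it is handed to `save_lora_weights`. The dictionary keys and the rank value are hypothetical and only serve as illustration.

```python
import torch

from diffusers.training_utils import _to_cpu_contiguous  # added by this patch

# Hypothetical gathered LoRA state dict: the tensor is a non-contiguous view,
# as can happen after collecting sharded parameters; non-tensor entries pass through unchanged.
lora_state_dict = {
    "transformer.blocks.0.attn.to_q.lora_A.weight": torch.randn(64, 16).t(),
    "rank": 16,
}

cpu_state_dict = _to_cpu_contiguous(lora_state_dict)

weight = cpu_state_dict["transformer.blocks.0.attn.to_q.lora_A.weight"]
assert weight.device.type == "cpu" and weight.is_contiguous()
assert cpu_state_dict["rank"] == 16
```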