From f931ec31a577cf0283be54a166c37e4dcd1ee800 Mon Sep 17 00:00:00 2001
From: js1234567
Date: Tue, 23 Dec 2025 15:56:13 +0800
Subject: [PATCH] Add FSDP option for Flux2

---
 .../dreambooth/train_dreambooth_lora_flux2.py |  64 +++++++++++++------
 .../train_dreambooth_lora_flux2_img2img.py    |  62 +++++++++++++-----
 src/diffusers/training_utils.py               |  19 ++++--
 3 files changed, 104 insertions(+), 41 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py
index 1bd95a3264..e25c7f1669 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2.py
@@ -1271,19 +1271,42 @@ def main(args):
 
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if accelerator.is_main_process:
-                transformer_lora_layers_to_save = None
-                modules_to_save = {}
+            transformer_lora_layers_to_save = None
+            modules_to_save = {}
+
+            if is_fsdp:
                 for model in models:
-                    if isinstance(model, type(unwrap_model(transformer))):
-                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                        modules_to_save["transformer"] = model
-                    else:
-                        raise ValueError(f"unexpected save model: {model.__class__}")
+                    if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                        state_dict = accelerator.get_state_dict(model)
 
-                    # make sure to pop weight so that corresponding model is not saved again
-                    weights.pop()
+                        if accelerator.is_main_process:
+                            transformer_lora_layers_to_save = get_peft_model_state_dict(
+                                unwrap_model(model), state_dict=state_dict
+                            )
+                            transformer_lora_layers_to_save = {
+                                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                                for k, v in transformer_lora_layers_to_save.items()
+                            }
+                            modules_to_save["transformer"] = model
+                    # make sure to pop weight so that corresponding model is not saved again
+                    if weights:
+                        weights.pop()
+            else:
+                if accelerator.is_main_process:
+                    transformer_lora_layers_to_save = None
+                    modules_to_save = {}
+                    for model in models:
+                        if isinstance(model, type(unwrap_model(transformer))):
+                            transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                            modules_to_save["transformer"] = model
+                        else:
+                            raise ValueError(f"unexpected save model: {model.__class__}")
+
+                        # make sure to pop weight so that corresponding model is not saved again
+                        weights.pop()
+
 
+            if accelerator.is_main_process:
                 Flux2Pipeline.save_lora_weights(
                     output_dir,
                     transformer_lora_layers=transformer_lora_layers_to_save,
@@ -1293,13 +1316,19 @@ def main(args):
         def load_model_hook(models, input_dir):
             transformer_ = None
 
-            while len(models) > 0:
-                model = models.pop()
+            if not is_fsdp:
+                while len(models) > 0:
+                    model = models.pop()
 
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_ = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+                    if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                        transformer_ = unwrap_model(model)
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
+            else:
+                transformer_ = Flux2Transformer2DModel.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="transformer"
+                )
+                transformer_.add_adapter(transformer_lora_config)
 
             lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
 
@@ -1802,7 +1831,7 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or is_fsdp:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
@@ -1861,7 +1890,6 @@ def main(args):
 
     # Save the lora layers
     accelerator.wait_for_everyone()
-    is_fsdp = accelerator.state.fsdp_plugin is not None
     if is_fsdp:
         transformer = unwrap_model(transformer)
 
diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
index d5372e01a3..2062994a0d 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
@@ -1208,19 +1208,41 @@ def main(args):
 
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
         def save_model_hook(models, weights, output_dir):
-            if accelerator.is_main_process:
-                transformer_lora_layers_to_save = None
-                modules_to_save = {}
+            transformer_lora_layers_to_save = None
+            modules_to_save = {}
+            if is_fsdp:
                 for model in models:
-                    if isinstance(model, type(unwrap_model(transformer))):
-                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                        modules_to_save["transformer"] = model
-                    else:
-                        raise ValueError(f"unexpected save model: {model.__class__}")
+                    if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                        state_dict = accelerator.get_state_dict(model)
 
-                    # make sure to pop weight so that corresponding model is not saved again
-                    weights.pop()
+                        if accelerator.is_main_process:
+                            transformer_lora_layers_to_save = get_peft_model_state_dict(
+                                unwrap_model(model), state_dict=state_dict
+                            )
+                            transformer_lora_layers_to_save = {
+                                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                                for k, v in transformer_lora_layers_to_save.items()
+                            }
+                            modules_to_save["transformer"] = model
+                    # make sure to pop weight so that corresponding model is not saved again
+                    if weights:
+                        weights.pop()
+            else:
+                if accelerator.is_main_process:
+                    transformer_lora_layers_to_save = None
+                    modules_to_save = {}
+                    for model in models:
+                        if isinstance(model, type(unwrap_model(transformer))):
+                            transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                            modules_to_save["transformer"] = model
+                        else:
+                            raise ValueError(f"unexpected save model: {model.__class__}")
+
+                        # make sure to pop weight so that corresponding model is not saved again
+                        weights.pop()
+
 
+            if accelerator.is_main_process:
                 Flux2Pipeline.save_lora_weights(
                     output_dir,
                     transformer_lora_layers=transformer_lora_layers_to_save,
@@ -1230,13 +1252,19 @@ def main(args):
         def load_model_hook(models, input_dir):
             transformer_ = None
 
-            while len(models) > 0:
-                model = models.pop()
+            if not is_fsdp:
+                while len(models) > 0:
+                    model = models.pop()
 
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_ = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+                    if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                        transformer_ = unwrap_model(model)
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
+            else:
+                transformer_ = Flux2Transformer2DModel.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="transformer"
+                )
+                transformer_.add_adapter(transformer_lora_config)
 
             lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
 
@@ -1725,7 +1753,7 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or is_fsdp:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 9407909cf0..56e5fe4e5a 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -5,16 +5,17 @@ import math
 import random
 import re
 import warnings
+from accelerate.logging import get_logger
 from contextlib import contextmanager
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
 import torch
-import torch.distributed as dist
-from torch.distributed.fsdp import CPUOffload, ShardingStrategy
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+if torch.distributed.is_available():
+    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
@@ -405,6 +406,11 @@ def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
     """
     kwargs = {}
 
+    fsdp_state = getattr(accelerator.state, "fsdp_plugin", None)
+
+    if fsdp_state is None:
+        raise ValueError("Accelerate is not configured for FSDP. Enable FSDP via `accelerate config` and retry.")
+
     fsdp_plugin = accelerator.state.fsdp_plugin
 
     if fsdp_plugin is None:
@@ -442,9 +448,12 @@
         FSDP-wrapped model
     """
 
+    logger = get_logger(__name__)
+
     if transformer_layer_cls is None:
         # Set the default layers if transformer_layer_cls is not provided
         transformer_layer_cls = type(model.model.language_model.layers[0])
+        logger.info(f"`transformer_layer_cls` was not provided; auto-inferred {transformer_layer_cls.__name__}")
 
     # Add auto-wrap policy if transformer layers specified
     auto_wrap_policy = partial(
@@ -464,8 +473,6 @@ def wrap_with_fsdp(
         config.update(fsdp_kwargs)
 
     fsdp_model = FSDP(model, **config)
-    if dist.is_initialized():
-        dist.barrier()
 
     return fsdp_model
 
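With the hooks above in place, FSDP is enabled through the Accelerate configuration rather than through a new script flag, so a run might look like the sketch below. The config file name, dataset path, prompt, and output directory are placeholders; only --pretrained_model_name_or_path and --checkpointing_steps are referenced by the hunks in this patch.

    # Sketch, assuming "fsdp.yaml" is an Accelerate config produced by `accelerate config`
    # with FSDP enabled; the model id, paths, and prompt below are illustrative only.
    accelerate launch --config_file fsdp.yaml \
      examples/dreambooth/train_dreambooth_lora_flux2.py \
      --pretrained_model_name_or_path=<flux2-checkpoint-or-hub-id> \
      --instance_data_dir=<training-images-dir> \
      --instance_prompt="a photo of sks dog" \
      --output_dir=<output-dir> \
      --mixed_precision=bf16 \
      --checkpointing_steps=250

The same launch pattern should apply to train_dreambooth_lora_flux2_img2img.py, since both scripts share the FSDP-aware save and load hooks.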