mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add FSDP option for Flux2

commit f931ec31a5 (parent 647c66aaf3)
Author: js1234567
Date: 2025-12-23 15:56:13 +08:00

3 changed files with 104 additions and 41 deletions

View File

@@ -1271,19 +1271,42 @@ def main(args):
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        if accelerator.is_main_process:
-            transformer_lora_layers_to_save = None
-            modules_to_save = {}
+        transformer_lora_layers_to_save = None
+        modules_to_save = {}
+        if is_fsdp:
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                    modules_to_save["transformer"] = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    state_dict = accelerator.get_state_dict(models)
+                    if accelerator.is_main_process:
+                        transformer_lora_layers_to_save = get_peft_model_state_dict(
+                            unwrap_model(model), state_dict=state_dict,
+                        )
+                        transformer_lora_layers_to_save = {
+                            k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                            for k, v in transformer_lora_layers_to_save.items()
+                        }
+                        modules_to_save["transformer"] = model
+
+                # make sure to pop weight so that corresponding model is not saved again
+                if weights:
+                    weights.pop()
+        else:
+            if accelerator.is_main_process:
+                transformer_lora_layers_to_save = None
+                modules_to_save = {}
+                for model in models:
+                    if isinstance(model, type(unwrap_model(transformer))):
+                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                        modules_to_save["transformer"] = model
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
+
+        if accelerator.is_main_process:
             Flux2Pipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
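
For orientation: the FSDP branch above follows the usual Accelerate pattern for saving a LoRA adapter from a sharded model. Every rank joins the collective `accelerator.get_state_dict(...)` call, and only the main process extracts the adapter weights and moves them to CPU. A minimal, self-contained sketch of that pattern (the helper name `gather_lora_state_dict` is illustrative, not part of this commit):

```python
import torch
from accelerate import Accelerator
from peft.utils import get_peft_model_state_dict


def gather_lora_state_dict(accelerator: Accelerator, wrapped_model, peft_model):
    # Collective call: every rank must participate so FSDP can unshard the weights.
    full_state_dict = accelerator.get_state_dict(wrapped_model)
    if not accelerator.is_main_process:
        return None
    # Keep only the LoRA parameters and detach them to CPU before serialization.
    lora_state_dict = get_peft_model_state_dict(peft_model, state_dict=full_state_dict)
    return {
        k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
        for k, v in lora_state_dict.items()
    }
```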
@@ -1293,13 +1316,19 @@ def main(args):
     def load_model_hook(models, input_dir):
         transformer_ = None
 
-        while len(models) > 0:
-            model = models.pop()
-
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+        if not is_fsdp:
+            while len(models) > 0:
+                model = models.pop()
+
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = Flux2Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer",
+            )
 
         transformer_.add_adapter(transformer_lora_config)
         lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
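
Under FSDP the hook receives the wrapped training copy rather than a plain PEFT module, so this branch rebuilds an unwrapped base transformer and attaches the adapter before loading the saved weights. A rough sketch of that branch, assuming `Flux2Transformer2DModel` and `Flux2Pipeline` are importable from `diffusers` (the helper name `rebuild_lora_transformer` is illustrative):

```python
from diffusers import Flux2Pipeline, Flux2Transformer2DModel


def rebuild_lora_transformer(pretrained_path, lora_config, input_dir):
    # Fresh, unwrapped base model; the FSDP-wrapped training copy is left untouched.
    transformer = Flux2Transformer2DModel.from_pretrained(pretrained_path, subfolder="transformer")
    transformer.add_adapter(lora_config)
    # The hook then loads this state dict into the adapter, as in the non-FSDP path.
    lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
    return transformer, lora_state_dict
```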
@@ -1802,7 +1831,7 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or is_fsdp:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
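
Widening the gate from `accelerator.is_main_process` to `accelerator.is_main_process or is_fsdp` matters because `accelerator.save_state(...)` becomes a collective operation once FSDP is active: every rank has to reach the call or the run can deadlock. A hedged sketch of that control flow (checkpoint rotation omitted):

```python
import os


def maybe_save_checkpoint(accelerator, args, global_step, is_fsdp):
    # With FSDP every rank must reach save_state(); without it, rank 0 alone is enough.
    if not (accelerator.is_main_process or is_fsdp):
        return
    if global_step % args.checkpointing_steps != 0:
        return
    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
    # save_state() gathers the sharded weights internally when an FSDP plugin is configured.
    accelerator.save_state(save_path)
```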
@@ -1861,7 +1890,6 @@ def main(args):
     # Save the lora layers
     accelerator.wait_for_everyone()
-    is_fsdp = accelerator.state.fsdp_plugin is not None
     if is_fsdp:
         transformer = unwrap_model(transformer)
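
The late re-computation of `is_fsdp` is dropped here; presumably the flag is now set once earlier in `main()`, after the `Accelerator` is constructed, so the hooks and the checkpointing gate can all share it. A minimal sketch of that detection (surrounding setup elided):

```python
from accelerate import Accelerator

accelerator = Accelerator()  # gradient accumulation, mixed precision, logging, etc. omitted
# True when the run was launched with an FSDP plugin, e.g. `accelerate launch --use_fsdp ...`.
is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
```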

View File

@@ -1208,19 +1208,41 @@ def main(args):
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        if accelerator.is_main_process:
-            transformer_lora_layers_to_save = None
-            modules_to_save = {}
+        transformer_lora_layers_to_save = None
+        modules_to_save = {}
+        if is_fsdp:
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                    modules_to_save["transformer"] = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
-
-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    state_dict = accelerator.get_state_dict(models)
+                    if accelerator.is_main_process:
+                        transformer_lora_layers_to_save = get_peft_model_state_dict(
+                            unwrap_model(model), state_dict=state_dict,
+                        )
+                        transformer_lora_layers_to_save = {
+                            k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                            for k, v in transformer_lora_layers_to_save.items()
+                        }
+                        modules_to_save["transformer"] = model
+
+                # make sure to pop weight so that corresponding model is not saved again
+                if weights:
+                    weights.pop()
+        else:
+            if accelerator.is_main_process:
+                transformer_lora_layers_to_save = None
+                modules_to_save = {}
+                for model in models:
+                    if isinstance(model, type(unwrap_model(transformer))):
+                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                        modules_to_save["transformer"] = model
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
+
+        if accelerator.is_main_process:
             Flux2Pipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
@@ -1230,13 +1252,19 @@ def main(args):
     def load_model_hook(models, input_dir):
         transformer_ = None
 
-        while len(models) > 0:
-            model = models.pop()
-
-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+        if not is_fsdp:
+            while len(models) > 0:
+                model = models.pop()
+
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = Flux2Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer",
+            )
 
         transformer_.add_adapter(transformer_lora_config)
         lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1725,7 +1753,7 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1
 
-                if accelerator.is_main_process:
+                if accelerator.is_main_process or is_fsdp:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:

View File

@@ -5,16 +5,17 @@ import math
 import random
 import re
 import warnings
 from accelerate.logging import get_logger
 from contextlib import contextmanager
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
 import torch
 import torch.distributed as dist
-from torch.distributed.fsdp import CPUOffload, ShardingStrategy
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
+if getattr(torch, "distributed", None) is not None:
+    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
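
Guarding these imports keeps the module importable on PyTorch builds without distributed support. A common variant of the same guard, shown here for comparison only (not part of this commit), keys off `torch.distributed.is_available()`:

```python
import torch

if torch.distributed.is_available():
    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
else:
    # Placeholders so that lookups fail loudly only when FSDP is actually requested.
    CPUOffload = ShardingStrategy = FSDP = transformer_auto_wrap_policy = None
```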
@@ -405,6 +406,11 @@ def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
     """
     kwargs = {}
+
+    fsdp_state = getattr(accelerator.state, "fsdp_plugin", None)
+    if fsdp_state is None:
+        raise ValueError("Accelerate isn't configured to handle FSDP. Please update your installation.")
+
     fsdp_plugin = accelerator.state.fsdp_plugin
     if fsdp_plugin is None:
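
`get_fsdp_kwargs_from_accelerator` reads Accelerate's FSDP plugin and turns it into keyword arguments for the `FSDP(...)` constructor used further down. The commit only shows the added guard, not the mapping itself; a speculative sketch of what such a mapping can look like:

```python
def fsdp_kwargs_from_plugin(fsdp_plugin) -> dict:
    # Illustrative only: Accelerate's FullyShardedDataParallelPlugin already stores
    # torch-native objects, so they can be forwarded to FSDP(...) as-is.
    kwargs = {}
    if getattr(fsdp_plugin, "sharding_strategy", None) is not None:
        kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy
    if getattr(fsdp_plugin, "cpu_offload", None) is not None:
        kwargs["cpu_offload"] = fsdp_plugin.cpu_offload
    return kwargs
```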
@@ -442,9 +448,12 @@
         FSDP-wrapped model
     """
     logger = get_logger(__name__)
 
     if transformer_layer_cls is None:
+        # Set the default layers if transformer_layer_cls is not provided
         transformer_layer_cls = type(model.model.language_model.layers[0])
+        logger.info(f"transformer_layer_cls is not provided, auto-inferred as {transformer_layer_cls.__name__}")
 
     # Add auto-wrap policy if transformer layers specified
     auto_wrap_policy = partial(
@@ -464,8 +473,6 @@
     config.update(fsdp_kwargs)
     fsdp_model = FSDP(model, **config)
 
-    if dist.is_initialized():
-        dist.barrier()
-
     return fsdp_model
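
Taken together, a training script detects the plugin, derives constructor kwargs (for example via `get_fsdp_kwargs_from_accelerator`), and shards the transformer with an auto-wrap policy before training starts. A hedged end-to-end sketch using the raw PyTorch FSDP API from the hunks above; the wrapper helper's exact signature is not visible in this diff, so the wrapping is done inline and the function name below is hypothetical:

```python
from functools import partial

import torch.nn as nn
from accelerate import Accelerator
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def shard_transformer(accelerator: Accelerator, model: nn.Module, block_cls: type, **fsdp_kwargs) -> nn.Module:
    # Hypothetical helper: returns the model unchanged when no FSDP plugin is configured.
    if getattr(accelerator.state, "fsdp_plugin", None) is None:
        return model
    # Shard at the granularity of the given transformer block class.
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={block_cls})
    return FSDP(model, auto_wrap_policy=auto_wrap_policy, **fsdp_kwargs)
```

Extra arguments such as `sharding_strategy` or `cpu_offload`, as produced by the kwargs helper, would be passed through `**fsdp_kwargs`.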