1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add FSDP option for Flux2

This commit is contained in:
js1234567
2025-12-18 19:55:36 +08:00
parent 55463f7ace
commit c766e27c77
3 changed files with 197 additions and 17 deletions

View File

@@ -47,6 +47,7 @@ from pathlib import Path
import numpy as np
import torch
import torch.distributed as dist
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -80,8 +81,10 @@ from diffusers.training_utils import (
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -722,6 +725,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1219,7 +1223,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)
is_fsdp = accelerator.state.fsdp_plugin is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1507,6 +1515,21 @@ def main(args):
args.validation_prompt, text_encoding_pipeline
)
# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)
text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1536,6 +1559,8 @@ def main(args):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1836,15 +1861,42 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
is_fsdp = accelerator.state.fsdp_plugin is not None
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}
else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(

View File

@@ -46,6 +46,7 @@ from pathlib import Path
import numpy as np
import torch
import torch.distributed as dist
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -79,8 +80,10 @@ from diffusers.training_utils import (
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -691,6 +694,7 @@ def parse_args(input_args=None):
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1156,7 +1160,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)
is_fsdp = accelerator.state.fsdp_plugin is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1430,6 +1438,21 @@ def main(args):
args.validation_prompt, text_encoding_pipeline
)
# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)
text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1461,6 +1484,8 @@ def main(args):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1759,15 +1784,41 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}
else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(

View File

@@ -6,10 +6,15 @@ import random
import re
import warnings
from contextlib import contextmanager
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Set, Type
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from functools import partial
from .models import UNet2DConditionModel
from .pipelines import DiffusionPipeline
@@ -394,6 +399,78 @@ def find_nearest_bucket(h, w, bucket_options):
return best_bucket_idx
def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
"""
Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs.
"""
kwargs = {}
fsdp_plugin = accelerator.state.fsdp_plugin
if fsdp_plugin is None:
# FSDP not enabled in Accelerator
kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD
else:
# FSDP is enabled → use plugin's strategy, or default if None
kwargs["sharding_strategy"] = (
fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
)
return kwargs
def wrap_with_fsdp(
model: torch.nn.Module,
device: Union[str, torch.device],
offload: bool = True,
use_orig_params: bool = True,
limit_all_gathers: bool = True,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
) -> FSDP:
"""
Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
Args:
model: Model to wrap
device: Target device (e.g., accelerator.device)
offload: Whether to enable CPU parameter offloading
use_orig_params: Whether to use original parameters
limit_all_gathers: Whether to limit all gathers
fsdp_kwargs: FSDP arguments (sharding_strategy, etc.) — usually from Accelerate config
transformer_layer_cls: Classes for auto-wrapping (if not using policy from fsdp_kwargs)
Returns:
FSDP-wrapped model
"""
if transformer_layer_cls is None:
# Set the default layers if transformer_layer_cls is not provided
transformer_layer_cls = type(model.model.language_model.layers[0])
# Add auto-wrap policy if transformer layers specified
auto_wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={transformer_layer_cls},
)
config = {
"device_id": device,
"cpu_offload": CPUOffload(offload_params=offload) if offload else None,
"use_orig_params": use_orig_params,
"limit_all_gathers": limit_all_gathers,
"auto_wrap_policy": auto_wrap_policy
}
if fsdp_kwargs:
config.update(fsdp_kwargs)
fsdp_model = FSDP(model, **config)
if dist.is_initialized():
dist.barrier()
return fsdp_model
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""