mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add FSDP option for Flux2
This commit is contained in:
@@ -47,6 +47,7 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
@@ -80,8 +81,10 @@ from diffusers.training_utils import (
|
||||
compute_loss_weighting_for_sd3,
|
||||
find_nearest_bucket,
|
||||
free_memory,
|
||||
get_fsdp_kwargs_from_accelerator,
|
||||
offload_models,
|
||||
parse_buckets_string,
|
||||
wrap_with_fsdp,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -722,6 +725,7 @@ def parse_args(input_args=None):
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
|
||||
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -1219,7 +1223,11 @@ def main(args):
|
||||
if args.bnb_quantization_config_path is not None
|
||||
else {"device": accelerator.device, "dtype": weight_dtype}
|
||||
)
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
is_fsdp = accelerator.state.fsdp_plugin is not None
|
||||
if not is_fsdp:
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
if args.do_fp8_training:
|
||||
convert_to_float8_training(
|
||||
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
|
||||
@@ -1507,6 +1515,21 @@ def main(args):
|
||||
args.validation_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
# Init FSDP for text encoder
|
||||
if args.fsdp_text_encoder:
|
||||
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
|
||||
text_encoder_fsdp = wrap_with_fsdp(
|
||||
model=text_encoding_pipeline.text_encoder,
|
||||
device=accelerator.device,
|
||||
offload=args.offload,
|
||||
limit_all_gathers=True,
|
||||
use_orig_params=True,
|
||||
fsdp_kwargs=fsdp_kwargs,
|
||||
)
|
||||
|
||||
text_encoding_pipeline.text_encoder = text_encoder_fsdp
|
||||
dist.barrier()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
@@ -1536,6 +1559,8 @@ def main(args):
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if args.remote_text_encoder:
|
||||
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
|
||||
elif args.fsdp_text_encoder:
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
else:
|
||||
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
@@ -1836,15 +1861,42 @@ def main(args):
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
is_fsdp = accelerator.state.fsdp_plugin is not None
|
||||
|
||||
if is_fsdp:
|
||||
transformer = unwrap_model(transformer)
|
||||
state_dict = accelerator.get_state_dict(transformer)
|
||||
if accelerator.is_main_process:
|
||||
modules_to_save = {}
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
if is_fsdp:
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
state_dict = {
|
||||
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
else:
|
||||
state_dict = {
|
||||
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
transformer_lora_layers = get_peft_model_state_dict(
|
||||
transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
transformer_lora_layers = {
|
||||
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in transformer_lora_layers.items()
|
||||
}
|
||||
|
||||
else:
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
|
||||
@@ -46,6 +46,7 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
@@ -79,8 +80,10 @@ from diffusers.training_utils import (
|
||||
compute_loss_weighting_for_sd3,
|
||||
find_nearest_bucket,
|
||||
free_memory,
|
||||
get_fsdp_kwargs_from_accelerator,
|
||||
offload_models,
|
||||
parse_buckets_string,
|
||||
wrap_with_fsdp,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -691,6 +694,7 @@ def parse_args(input_args=None):
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
|
||||
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -1156,7 +1160,11 @@ def main(args):
|
||||
if args.bnb_quantization_config_path is not None
|
||||
else {"device": accelerator.device, "dtype": weight_dtype}
|
||||
)
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
is_fsdp = accelerator.state.fsdp_plugin is not None
|
||||
if not is_fsdp:
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
if args.do_fp8_training:
|
||||
convert_to_float8_training(
|
||||
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
|
||||
@@ -1430,6 +1438,21 @@ def main(args):
|
||||
args.validation_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
# Init FSDP for text encoder
|
||||
if args.fsdp_text_encoder:
|
||||
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
|
||||
text_encoder_fsdp = wrap_with_fsdp(
|
||||
model=text_encoding_pipeline.text_encoder,
|
||||
device=accelerator.device,
|
||||
offload=args.offload,
|
||||
limit_all_gathers=True,
|
||||
use_orig_params=True,
|
||||
fsdp_kwargs=fsdp_kwargs,
|
||||
)
|
||||
|
||||
text_encoding_pipeline.text_encoder = text_encoder_fsdp
|
||||
dist.barrier()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
@@ -1461,6 +1484,8 @@ def main(args):
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if args.remote_text_encoder:
|
||||
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
|
||||
elif args.fsdp_text_encoder:
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
else:
|
||||
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
@@ -1759,15 +1784,41 @@ def main(args):
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_fsdp:
|
||||
transformer = unwrap_model(transformer)
|
||||
state_dict = accelerator.get_state_dict(transformer)
|
||||
if accelerator.is_main_process:
|
||||
modules_to_save = {}
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
if is_fsdp:
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
state_dict = {
|
||||
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
else:
|
||||
state_dict = {
|
||||
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
transformer_lora_layers = get_peft_model_state_dict(
|
||||
transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
transformer_lora_layers = {
|
||||
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in transformer_lora_layers.items()
|
||||
}
|
||||
|
||||
else:
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
|
||||
@@ -6,10 +6,15 @@ import random
|
||||
import re
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Set, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed.fsdp import CPUOffload, ShardingStrategy
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from functools import partial
|
||||
|
||||
from .models import UNet2DConditionModel
|
||||
from .pipelines import DiffusionPipeline
|
||||
@@ -394,6 +399,78 @@ def find_nearest_bucket(h, w, bucket_options):
|
||||
return best_bucket_idx
|
||||
|
||||
|
||||
def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
|
||||
"""
|
||||
Extract and convert FSDP config from Accelerator into PyTorch FSDP kwargs.
|
||||
"""
|
||||
|
||||
kwargs = {}
|
||||
fsdp_plugin = accelerator.state.fsdp_plugin
|
||||
|
||||
if fsdp_plugin is None:
|
||||
# FSDP not enabled in Accelerator
|
||||
kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD
|
||||
else:
|
||||
# FSDP is enabled → use plugin's strategy, or default if None
|
||||
kwargs["sharding_strategy"] = (
|
||||
fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
|
||||
)
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
def wrap_with_fsdp(
|
||||
model: torch.nn.Module,
|
||||
device: Union[str, torch.device],
|
||||
offload: bool = True,
|
||||
use_orig_params: bool = True,
|
||||
limit_all_gathers: bool = True,
|
||||
fsdp_kwargs: Optional[Dict[str, Any]] = None,
|
||||
transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
|
||||
) -> FSDP:
|
||||
"""
|
||||
Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
|
||||
|
||||
Args:
|
||||
model: Model to wrap
|
||||
device: Target device (e.g., accelerator.device)
|
||||
offload: Whether to enable CPU parameter offloading
|
||||
use_orig_params: Whether to use original parameters
|
||||
limit_all_gathers: Whether to limit all gathers
|
||||
fsdp_kwargs: FSDP arguments (sharding_strategy, etc.) — usually from Accelerate config
|
||||
transformer_layer_cls: Classes for auto-wrapping (if not using policy from fsdp_kwargs)
|
||||
|
||||
Returns:
|
||||
FSDP-wrapped model
|
||||
"""
|
||||
|
||||
if transformer_layer_cls is None:
|
||||
# Set the default layers if transformer_layer_cls is not provided
|
||||
transformer_layer_cls = type(model.model.language_model.layers[0])
|
||||
|
||||
# Add auto-wrap policy if transformer layers specified
|
||||
auto_wrap_policy = partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={transformer_layer_cls},
|
||||
)
|
||||
|
||||
config = {
|
||||
"device_id": device,
|
||||
"cpu_offload": CPUOffload(offload_params=offload) if offload else None,
|
||||
"use_orig_params": use_orig_params,
|
||||
"limit_all_gathers": limit_all_gathers,
|
||||
"auto_wrap_policy": auto_wrap_policy
|
||||
}
|
||||
|
||||
if fsdp_kwargs:
|
||||
config.update(fsdp_kwargs)
|
||||
|
||||
fsdp_model = FSDP(model, **config)
|
||||
if dist.is_initialized():
|
||||
dist.barrier()
|
||||
return fsdp_model
|
||||
|
||||
|
||||
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
||||
class EMAModel:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user