mirror of https://github.com/huggingface/diffusers.git
Add FSDP option for Flux2
@@ -1271,19 +1271,42 @@ def main(args):

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        if accelerator.is_main_process:
-            transformer_lora_layers_to_save = None
-            modules_to_save = {}
+        transformer_lora_layers_to_save = None
+        modules_to_save = {}
+
+        if is_fsdp:
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                    modules_to_save["transformer"] = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    state_dict = accelerator.get_state_dict(models)
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
+                    if accelerator.is_main_process:
+                        transformer_lora_layers_to_save = get_peft_model_state_dict(
+                            unwrap_model(model), state_dict=state_dict,
+                        )
+                        transformer_lora_layers_to_save = {
+                            k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                            for k, v in transformer_lora_layers_to_save.items()
+                        }
+                        modules_to_save["transformer"] = model
+
+                # make sure to pop weight so that corresponding model is not saved again
+                if weights:
+                    weights.pop()
+        else:
+            if accelerator.is_main_process:
+                transformer_lora_layers_to_save = None
+                modules_to_save = {}
+                for model in models:
+                    if isinstance(model, type(unwrap_model(transformer))):
+                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                        modules_to_save["transformer"] = model
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")

-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()

+        if accelerator.is_main_process:
             Flux2Pipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
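In the FSDP branch above, `accelerator.get_state_dict(...)` has to run on every rank so the sharded parameters can be gathered into one full state dict; only the main process then filters out the LoRA keys and moves them to CPU. A minimal standalone sketch of that pattern follows; `gather_lora_state_dict` is a hypothetical helper, not a function from the script or from diffusers.

import torch
from peft import get_peft_model_state_dict

def gather_lora_state_dict(accelerator, wrapped_model):
    # Collective under FSDP: every rank must call get_state_dict() so the sharded
    # parameters can be reassembled into one full state dict.
    full_state_dict = accelerator.get_state_dict(wrapped_model)

    lora_state_dict = {}
    if accelerator.is_main_process:
        # Extract only the LoRA adapter weights and detach them to CPU so they can
        # be written to disk without holding onto GPU memory.
        lora_state_dict = get_peft_model_state_dict(
            accelerator.unwrap_model(wrapped_model), state_dict=full_state_dict
        )
        lora_state_dict = {
            k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
            for k, v in lora_state_dict.items()
        }
    return lora_state_dict

The `.detach().cpu().contiguous()` copy keeps the gathered tensors serializable without pinning GPU memory on the saving rank.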
@@ -1293,13 +1316,19 @@ def main(args):
     def load_model_hook(models, input_dir):
         transformer_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not is_fsdp:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = Flux2Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer",
+            )
+            transformer_.add_adapter(transformer_lora_config)

         lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
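Under FSDP the objects handed back to the load hook are flat-parameter wrappers, so the hook above rebuilds the base transformer from the pretrained weights and re-attaches a LoRA adapter before loading the saved state. A sketch of that resume path; the helper name, the import path, and the LoRA rank/alpha/target modules are assumptions for illustration only.

from diffusers import Flux2Transformer2DModel
from peft import LoraConfig

def rebuild_transformer_for_resume(pretrained_model_name_or_path):
    lora_config = LoraConfig(r=16, lora_alpha=16, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
    # Under FSDP the hook cannot reuse the wrapped module directly, so load the base
    # weights fresh and re-attach an adapter before restoring the LoRA state dict.
    transformer = Flux2Transformer2DModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="transformer"
    )
    transformer.add_adapter(lora_config)
    return transformer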
@@ -1802,7 +1831,7 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or is_fsdp:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
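`accelerator.save_state(...)` is a collective operation when FSDP is active, so gating checkpointing on `is_main_process` alone can hang the other ranks; hence the `or is_fsdp` in the condition above. A sketch of the guard factored into a helper, with names assumed from the surrounding training loop.

import os

def maybe_save_checkpoint(accelerator, args, global_step):
    is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None

    # Every rank must reach save_state() when FSDP is active, so the gate cannot be
    # main-process-only as it was before this change.
    if (accelerator.is_main_process or is_fsdp) and global_step % args.checkpointing_steps == 0:
        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
        accelerator.save_state(save_path)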
@@ -1861,7 +1890,6 @@ def main(args):

     # Save the lora layers
     accelerator.wait_for_everyone()
-    is_fsdp = accelerator.state.fsdp_plugin is not None

     if is_fsdp:
         transformer = unwrap_model(transformer)
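The late `is_fsdp = accelerator.state.fsdp_plugin is not None` assignment is dropped because the flag is now available earlier in the script and reused everywhere. A sketch of computing it once near setup; the exact placement is an assumption.

from accelerate import Accelerator

# Detect FSDP once, right after the Accelerator exists, and reuse the flag in the
# save/load hooks, the checkpointing guard, and the final export.
accelerator = Accelerator()
is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None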
@@ -1208,19 +1208,41 @@ def main(args):

     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
-        if accelerator.is_main_process:
-            transformer_lora_layers_to_save = None
-            modules_to_save = {}
+        transformer_lora_layers_to_save = None
+        modules_to_save = {}
+        if is_fsdp:
             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
-                    transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                    modules_to_save["transformer"] = model
-                else:
-                    raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    state_dict = accelerator.get_state_dict(models)
+
+                    # make sure to pop weight so that corresponding model is not saved again
+                    weights.pop()
+                    if accelerator.is_main_process:
+                        transformer_lora_layers_to_save = get_peft_model_state_dict(
+                            unwrap_model(model), state_dict=state_dict,
+                        )
+                        transformer_lora_layers_to_save = {
+                            k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
+                            for k, v in transformer_lora_layers_to_save.items()
+                        }
+                        modules_to_save["transformer"] = model
+
+                # make sure to pop weight so that corresponding model is not saved again
+                if weights:
+                    weights.pop()
+        else:
+            if accelerator.is_main_process:
+                transformer_lora_layers_to_save = None
+                modules_to_save = {}
+                for model in models:
+                    if isinstance(model, type(unwrap_model(transformer))):
+                        transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                        modules_to_save["transformer"] = model
+                    else:
+                        raise ValueError(f"unexpected save model: {model.__class__}")

-                # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()

+        if accelerator.is_main_process:
             Flux2Pipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
@@ -1230,13 +1252,19 @@ def main(args):
     def load_model_hook(models, input_dir):
         transformer_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not is_fsdp:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+        else:
+            transformer_ = Flux2Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer",
+            )
+            transformer_.add_adapter(transformer_lora_config)

         lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1725,7 +1753,7 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or is_fsdp:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
@@ -5,16 +5,17 @@ import math
 import random
 import re
 import warnings
+from accelerate.logging import get_logger
 from contextlib import contextmanager
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

 import numpy as np
 import torch
-import torch.distributed as dist
-from torch.distributed.fsdp import CPUOffload, ShardingStrategy
-from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+if getattr(torch, "distributed", None) is not None:
+    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
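Guarding the torch.distributed.fsdp imports keeps the module importable on torch builds without distributed support, deferring the failure to the point where FSDP is actually requested. A sketch of the same pattern with an explicit availability flag; `_fsdp_available` and `require_fsdp` are hypothetical names, not part of diffusers.

import torch

_fsdp_available = getattr(torch, "distributed", None) is not None

if _fsdp_available:
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def require_fsdp():
    # Fail with a clear message only when FSDP functionality is requested,
    # instead of failing at import time on builds without distributed support.
    if not _fsdp_available:
        raise ImportError("torch was built without distributed support; FSDP utilities are unavailable.")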
@@ -405,6 +406,11 @@ def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
     """

     kwargs = {}
+    fsdp_state = getattr(accelerator.state, "fsdp_plugin", None)
+
+    if fsdp_state is None:
+        raise ValueError("Accelerate isn't configured to handle FSDP. Please update your installation.")
+
     fsdp_plugin = accelerator.state.fsdp_plugin

     if fsdp_plugin is None:
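The added check makes `get_fsdp_kwargs_from_accelerator` fail fast when the Accelerator was not set up with an FSDP plugin. A sketch of a setup under which the check passes; the hand-built plugin is illustrative only, and real runs usually configure this through `accelerate config` or environment variables.

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Run under `accelerate launch` so the distributed environment is initialized.
fsdp_plugin = FullyShardedDataParallelPlugin()
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

# This is what the helper checks before reading any FSDP settings.
assert getattr(accelerator.state, "fsdp_plugin", None) is not None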
@@ -442,9 +448,12 @@ def wrap_with_fsdp(
         FSDP-wrapped model
     """

+    logger = get_logger(__name__)

     if transformer_layer_cls is None:
+        # Set the default layers if transformer_layer_cls is not provided
         transformer_layer_cls = type(model.model.language_model.layers[0])
+        logger.info(f"transformer_layer_cls is not provided, auto-inferred as {transformer_layer_cls.__name__}")

     # Add auto-wrap policy if transformer layers specified
     auto_wrap_policy = partial(
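With the layer class resolved, either passed in or auto-inferred and logged as above, the auto-wrap policy is built with `functools.partial`. A self-contained sketch with a placeholder block class standing in for whatever class is inferred.

import torch
from functools import partial
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

class PlaceholderTransformerBlock(torch.nn.Module):
    # Stand-in for whatever layer class gets passed in or auto-inferred.
    pass

# Wrap every submodule of the given block class in its own FSDP unit.
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={PlaceholderTransformerBlock},
)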
@@ -464,8 +473,6 @@ def wrap_with_fsdp(
     config.update(fsdp_kwargs)

     fsdp_model = FSDP(model, **config)
-    if dist.is_initialized():
-        dist.barrier()
     return fsdp_model
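The explicit `dist.barrier()` after construction is removed; FSDP's constructor already performs collective work across ranks, so the extra synchronization point appears redundant. A minimal wrapping sketch under that assumption; the `wrap` helper is hypothetical and not the library's API.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap(model, auto_wrap_policy=None, **fsdp_kwargs):
    config = {"auto_wrap_policy": auto_wrap_policy}
    config.update(fsdp_kwargs)
    # Requires torch.distributed to be initialized (e.g. via `accelerate launch`).
    return FSDP(model, **config)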