Mirror of https://github.com/huggingface/diffusers.git

Add FSDP option for Flux2

Author: js1234567
Date: 2025-12-23 16:23:48 +08:00
Parent: f931ec31a5
Commit: 8bce38c086
3 changed files with 11 additions and 5 deletions


@@ -1281,7 +1281,8 @@ def main(args):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = get_peft_model_state_dict(
-                unwrap_model(model), state_dict=state_dict,
+                unwrap_model(model),
+                state_dict=state_dict,
             )
             transformer_lora_layers_to_save = {
                 k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
@@ -1326,7 +1327,8 @@ def main(args):
                     raise ValueError(f"unexpected save model: {model.__class__}")
         else:
             transformer_ = Flux2Transformer2DModel.from_pretrained(
-                args.pretrained_model_name_or_path, subfolder="transformer",
+                args.pretrained_model_name_or_path,
+                subfolder="transformer",
             )
             transformer_.add_adapter(transformer_lora_config)
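
Both training scripts in this commit receive the same two-hunk change: the get_peft_model_state_dict(...) call is split across lines so that the explicit state_dict argument stands out. The sketch below is a minimal, hypothetical illustration of the save-hook pattern this call sits in and of why the explicit state_dict matters under FSDP. The helper name collect_lora_layers_for_saving and the use of accelerator.get_state_dict() to gather the sharded weights are assumptions for illustration, not code from this commit.

import torch
from peft.utils import get_peft_model_state_dict


def collect_lora_layers_for_saving(model, accelerator, unwrap_model):
    """Return LoRA-only weights as plain CPU tensors, gathering FSDP shards first."""
    # Under FSDP the parameters live in sharded flat buffers, so the full state dict
    # is gathered explicitly and handed to peft (assumption: accelerator.get_state_dict()
    # is where the `state_dict` in the diff comes from).
    state_dict = accelerator.get_state_dict(model)
    lora_layers = get_peft_model_state_dict(
        unwrap_model(model),
        state_dict=state_dict,
    )
    # Detach, move to CPU, and make contiguous so the saved tensors no longer
    # alias any distributed buffers.
    return {
        k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
        for k, v in lora_layers.items()
    }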


@@ -1217,7 +1217,8 @@ def main(args):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = get_peft_model_state_dict(
-                unwrap_model(model), state_dict=state_dict,
+                unwrap_model(model),
+                state_dict=state_dict,
             )
             transformer_lora_layers_to_save = {
                 k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
@@ -1262,7 +1263,8 @@ def main(args):
                     raise ValueError(f"unexpected save model: {model.__class__}")
         else:
             transformer_ = Flux2Transformer2DModel.from_pretrained(
-                args.pretrained_model_name_or_path, subfolder="transformer",
+                args.pretrained_model_name_or_path,
+                subfolder="transformer",
            )
             transformer_.add_adapter(transformer_lora_config)
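
The load path shown in both hunks above reloads the base transformer and re-attaches the LoRA adapter config so the saved LoRA weights have modules to load into. Below is a hedged sketch of that step, assuming a diffusers build that exports Flux2Transformer2DModel; the rank, alpha, target modules, and model id are placeholders, not the script's actual arguments.

from diffusers import Flux2Transformer2DModel
from peft import LoraConfig

transformer_lora_config = LoraConfig(
    r=16,                       # placeholder rank
    lora_alpha=16,              # placeholder alpha
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # placeholder targets
)

transformer_ = Flux2Transformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.2-dev",  # placeholder for args.pretrained_model_name_or_path
    subfolder="transformer",
)
# Registers empty LoRA layers whose keys match what get_peft_model_state_dict produced.
transformer_.add_adapter(transformer_lora_config)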


@@ -5,13 +5,15 @@ import math
 import random
 import re
 import warnings
-from accelerate.logging import get_logger
 from contextlib import contextmanager
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 import numpy as np
 import torch
+from accelerate.logging import get_logger
+
+if getattr(torch, "distributed", None) is not None:
+    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
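
The third file only adds the guarded import above, making the FSDP symbols available when torch ships with torch.distributed. As a rough sketch of what that import enables, assuming an initialized process group and a CUDA device, the transformer could be wrapped as below; the wrap_with_fsdp helper and its argument choices are illustrative, not the commit's actual FSDP configuration.

import torch
import torch.nn as nn

if getattr(torch, "distributed", None) is not None:
    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_with_fsdp(transformer: nn.Module, offload_params: bool = False) -> nn.Module:
    """Shard parameters, gradients, and optimizer state across ranks (FULL_SHARD)."""
    # Requires a distributed build of torch and torch.distributed.init_process_group()
    # to have been called before this function runs.
    return FSDP(
        transformer,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=offload_params),
        device_id=torch.cuda.current_device(),
        use_orig_params=True,  # keep original parameter objects, useful when only LoRA params require grad
    )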