Mirror of https://github.com/huggingface/diffusers.git
Add FSDP option for Flux2
@@ -1281,7 +1281,8 @@ def main(args):
 
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = get_peft_model_state_dict(
-                unwrap_model(model), state_dict=state_dict,
+                unwrap_model(model),
+                state_dict=state_dict,
             )
             transformer_lora_layers_to_save = {
                 k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
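For context on the state_dict=state_dict argument: under FSDP each rank only holds a shard of the parameters, so the usual pattern is to gather a consolidated state dict first and then let get_peft_model_state_dict filter it down to the LoRA keys. Below is a minimal sketch of that pattern assuming an accelerate setup; the checkpoint filename, the final torch.save call, and the weights.pop() handling are placeholders, not this script's actual code.

import torch
from accelerate import Accelerator
from peft.utils import get_peft_model_state_dict

accelerator = Accelerator()

def save_model_hook(models, weights, output_dir):
    for model in models:
        # Under FSDP this gathers the sharded parameters so rank 0 sees the full weights.
        full_state_dict = accelerator.get_state_dict(model)
        if accelerator.is_main_process:
            lora_state_dict = get_peft_model_state_dict(
                accelerator.unwrap_model(model),
                state_dict=full_state_dict,
            )
            # Detach and move to CPU so serialization keeps no GPU references alive.
            lora_state_dict = {
                k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
                for k, v in lora_state_dict.items()
            }
            torch.save(lora_state_dict, f"{output_dir}/pytorch_lora_weights.bin")
        # Drop this model's entry so accelerate does not also serialize the full weights.
        if weights:
            weights.pop()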
@@ -1326,7 +1327,8 @@ def main(args):
                     raise ValueError(f"unexpected save model: {model.__class__}")
             else:
                 transformer_ = Flux2Transformer2DModel.from_pretrained(
-                    args.pretrained_model_name_or_path, subfolder="transformer",
+                    args.pretrained_model_name_or_path,
+                    subfolder="transformer",
                 )
                 transformer_.add_adapter(transformer_lora_config)
 
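The else: branch covers backends where the model object handed to the hook cannot simply be reused (FSDP wraps and shards it), so a fresh Flux2Transformer2DModel is instantiated and given the LoRA adapter before the saved weights are loaded back. A hedged sketch of the loading step that typically follows in such a hook; the checkpoint filename, the "transformer." key prefix, and the helper name are assumptions for illustration, not confirmed details of this script.

import safetensors.torch
from peft.utils import set_peft_model_state_dict

def load_lora_into_transformer(transformer_, input_dir: str) -> None:
    # Load the serialized LoRA weights (filename is an assumption).
    lora_state_dict = safetensors.torch.load_file(f"{input_dir}/pytorch_lora_weights.safetensors")
    # Keep only the transformer's keys and strip the (assumed) "transformer." prefix.
    transformer_state_dict = {
        k.removeprefix("transformer."): v
        for k, v in lora_state_dict.items()
        if k.startswith("transformer.")
    }
    incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
    if incompatible_keys is not None and getattr(incompatible_keys, "unexpected_keys", None):
        raise ValueError(f"Unexpected keys when loading LoRA weights: {incompatible_keys.unexpected_keys}")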
@@ -1217,7 +1217,8 @@ def main(args):
 
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = get_peft_model_state_dict(
-                unwrap_model(model), state_dict=state_dict,
+                unwrap_model(model),
+                state_dict=state_dict,
            )
             transformer_lora_layers_to_save = {
                 k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
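At the plain PyTorch level, the equivalent of handing a pre-gathered state dict to get_peft_model_state_dict is FSDP's full-state-dict mode, which all-gathers the shards and can offload the result to CPU on rank 0 only. A short sketch of that torch API, independent of this script's accelerate-based flow; fsdp_model is a placeholder for an already-wrapped module.

from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Gather a consolidated state dict: shards are all-gathered, moved to CPU,
# and materialized only on rank 0 to bound peak memory.
full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, full_cfg):
    state_dict = fsdp_model.state_dict()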
@@ -1262,7 +1263,8 @@ def main(args):
                     raise ValueError(f"unexpected save model: {model.__class__}")
             else:
                 transformer_ = Flux2Transformer2DModel.from_pretrained(
-                    args.pretrained_model_name_or_path, subfolder="transformer",
+                    args.pretrained_model_name_or_path,
+                    subfolder="transformer",
                 )
                 transformer_.add_adapter(transformer_lora_config)
 
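Hooks like the ones in these hunks only take effect once they are registered on the Accelerator; accelerate then calls them around its own checkpointing. A minimal sketch of that wiring, with save_model_hook and load_model_hook standing in for the hooks defined in this script and the checkpoint directory name as a placeholder.

from accelerate import Accelerator

accelerator = Accelerator()

# `save_model_hook` / `load_model_hook` stand in for the hooks shown in the hunks above.
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# During training, checkpointing invokes the save hook on every process:
accelerator.save_state("checkpoint-500")

# When resuming, the load hook runs before accelerate restores the remaining state:
accelerator.load_state("checkpoint-500")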
@@ -5,13 +5,15 @@ import math
 import random
 import re
 import warnings
-from accelerate.logging import get_logger
 from contextlib import contextmanager
 from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
 import torch
+from accelerate.logging import get_logger
+
+
+if getattr(torch, "distributed", None) is not None:
+    from torch.distributed.fsdp import CPUOffload, ShardingStrategy
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
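The guarded import appears intended to keep the module importable on PyTorch builds where torch.distributed is unavailable, while still exposing the FSDP symbols where it is. For reference, a sketch of what these symbols are typically used for; the wrapping flags below are illustrative assumptions rather than this script's configuration, and a distributed process group must already be initialized for the call to work.

import torch
from torch.distributed.fsdp import CPUOffload, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_with_fsdp(module: torch.nn.Module) -> FSDP:
    # FULL_SHARD shards parameters, gradients, and optimizer state across ranks;
    # CPUOffload(offload_params=True) keeps the sharded parameters on CPU between uses.
    return FSDP(
        module,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        cpu_offload=CPUOffload(offload_params=True),
    )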