Apply style fixes

Author: github-actions[bot]
Date: 2025-12-22 05:42:17 +00:00
parent 0052b21f52
commit 647c66aaf3

@@ -6,7 +6,8 @@ import random
 import re
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Set, Type
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union

 import numpy as np
 import torch
@@ -14,7 +15,6 @@ import torch.distributed as dist
 from torch.distributed.fsdp import CPUOffload, ShardingStrategy
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-from functools import partial

 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
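
For orientation, the relocated functools.partial import typically pairs with the transformer_auto_wrap_policy import above to build an auto-wrap policy. A minimal sketch, assuming BasicTransformerBlock as a stand-in for whatever layer classes the caller supplies; the policy construction itself is not part of this diff:

    from functools import partial

    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    from diffusers.models.attention import BasicTransformerBlock  # assumed example layer class

    # partial binds the layer-class set into the policy; FSDP then calls the
    # resulting callable on every submodule to decide whether to wrap it.
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={BasicTransformerBlock},
    )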
@@ -412,21 +412,19 @@ def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
         kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD
     else:
         # FSDP is enabled → use plugin's strategy, or default if None
-        kwargs["sharding_strategy"] = (
-            fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
-        )
+        kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD

     return kwargs


 def wrap_with_fsdp(
-    model: torch.nn.Module,
-    device: Union[str, torch.device],
-    offload: bool = True,
-    use_orig_params: bool = True,
-    limit_all_gathers: bool = True,
-    fsdp_kwargs: Optional[Dict[str, Any]] = None,
-    transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
+    model: torch.nn.Module,
+    device: Union[str, torch.device],
+    offload: bool = True,
+    use_orig_params: bool = True,
+    limit_all_gathers: bool = True,
+    fsdp_kwargs: Optional[Dict[str, Any]] = None,
+    transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
 ) -> FSDP:
     """
     Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
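
For context, a minimal sketch of how the two helpers in this hunk compose, based only on the signatures visible here. Running it requires an initialized torch.distributed process group (e.g. via accelerate launch), and passing the derived defaults through fsdp_kwargs is an assumption, not something this diff shows:

    from accelerate import Accelerator
    from diffusers import UNet2DConditionModel
    from diffusers.models.attention import BasicTransformerBlock

    accelerator = Accelerator()
    unet = UNet2DConditionModel()

    # Derive sharding defaults from the accelerator's FSDP plugin (falls back
    # to FULL_SHARD when no plugin is configured), then wrap with CPU offload
    # and transformer auto-wrapping enabled.
    defaults = get_fsdp_kwargs_from_accelerator(accelerator)
    wrapped = wrap_with_fsdp(
        unet,
        device=accelerator.device,
        offload=True,
        fsdp_kwargs=defaults,
        transformer_layer_cls={BasicTransformerBlock},
    )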
@@ -459,7 +457,7 @@ def wrap_with_fsdp(
         "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
         "use_orig_params": use_orig_params,
         "limit_all_gathers": limit_all_gathers,
-        "auto_wrap_policy": auto_wrap_policy
+        "auto_wrap_policy": auto_wrap_policy,
     }

     if fsdp_kwargs:
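
The hunk ends just before the override merge. A hypothetical continuation, with ctor_kwargs standing in for the real dict name that the hunk cuts off and the device handling assumed from the function signature:

    # Caller-supplied options take precedence over the defaults built above
    # (ctor_kwargs is an assumed name; the actual identifier is outside this hunk).
    if fsdp_kwargs:
        ctor_kwargs.update(fsdp_kwargs)

    # device_id is a real FSDP constructor argument; using the helper's
    # `device` parameter for it is an assumption.
    return FSDP(model, device_id=device, **ctor_kwargs)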