Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Apply style fixes
@@ -6,7 +6,8 @@ import random
 import re
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Set, Type
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -14,7 +15,6 @@ import torch.distributed as dist
 from torch.distributed.fsdp import CPUOffload, ShardingStrategy
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-from functools import partial
 
 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
@@ -412,21 +412,19 @@ def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
         kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD
     else:
         # FSDP is enabled → use plugin's strategy, or default if None
-        kwargs["sharding_strategy"] = (
-            fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
-        )
+        kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
 
     return kwargs
 
 
 def wrap_with_fsdp(
     model: torch.nn.Module,
     device: Union[str, torch.device],
     offload: bool = True,
     use_orig_params: bool = True,
     limit_all_gathers: bool = True,
     fsdp_kwargs: Optional[Dict[str, Any]] = None,
     transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
 ) -> FSDP:
     """
     Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
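
The hunk above collapses a parenthesized assignment onto one line; behavior is unchanged, since `or` falls back to `FULL_SHARD` whenever the plugin leaves `sharding_strategy` as `None`. A minimal sketch of that fallback, assuming only torch's `ShardingStrategy` enum (the `_Plugin` stub is a hypothetical stand-in, not part of the diff):

from torch.distributed.fsdp import ShardingStrategy

class _Plugin:
    # hypothetical stand-in for an accelerate-style FSDP plugin object
    sharding_strategy = None  # may be left unset by the plugin

plugin = _Plugin()
# `or` returns the plugin's strategy when set, else the FULL_SHARD default,
# matching the one-line assignment in the hunk above
assert (plugin.sharding_strategy or ShardingStrategy.FULL_SHARD) is ShardingStrategy.FULL_SHARD

plugin.sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
assert (plugin.sharding_strategy or ShardingStrategy.FULL_SHARD) is ShardingStrategy.SHARD_GRAD_OP
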
@@ -459,7 +457,7 @@ def wrap_with_fsdp(
         "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
         "use_orig_params": use_orig_params,
         "limit_all_gathers": limit_all_gathers,
-        "auto_wrap_policy": auto_wrap_policy
+        "auto_wrap_policy": auto_wrap_policy,
     }
 
     if fsdp_kwargs:
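
The trailing-comma fix lands in the kwargs dict that feeds `auto_wrap_policy` into FSDP. For context, a sketch of how such a policy is presumably built from `transformer_layer_cls` (the `Block` class is an illustrative stand-in; only `transformer_auto_wrap_policy` and `functools.partial` come from the diff's imports):

from functools import partial

import torch
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class Block(torch.nn.Module):
    # illustrative transformer-block stand-in, not part of the diff
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)


# FSDP expects `auto_wrap_policy` to be a callable; the standard pattern
# binds the set of layer classes with functools.partial, which would
# explain why this commit keeps the `partial` import and merely re-sorts it.
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={Block},
)

Every submodule that is an instance of a class in `transformer_layer_cls` then becomes its own FSDP unit, which bounds peak memory during parameter all-gathers.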