From 647c66aaf3bdba17b4601d9ff971da2e8ce92e50 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
Date: Mon, 22 Dec 2025 05:42:17 +0000
Subject: [PATCH] Apply style fixes

---
 src/diffusers/training_utils.py | 24 +++++++++++-------------
 1 file changed, 11 insertions(+), 13 deletions(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 7edd21be24..9407909cf0 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -6,7 +6,8 @@ import random
 import re
 import warnings
 from contextlib import contextmanager
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Set, Type
+from functools import partial
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -14,7 +15,6 @@ import torch.distributed as dist
 from torch.distributed.fsdp import CPUOffload, ShardingStrategy
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
-from functools import partial
 
 from .models import UNet2DConditionModel
 from .pipelines import DiffusionPipeline
@@ -412,21 +412,19 @@ def get_fsdp_kwargs_from_accelerator(accelerator) -> dict:
         kwargs["sharding_strategy"] = ShardingStrategy.FULL_SHARD
     else:
         # FSDP is enabled → use plugin's strategy, or default if None
-        kwargs["sharding_strategy"] = (
-            fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
-        )
+        kwargs["sharding_strategy"] = fsdp_plugin.sharding_strategy or ShardingStrategy.FULL_SHARD
 
     return kwargs
 
 
 def wrap_with_fsdp(
-        model: torch.nn.Module,
-        device: Union[str, torch.device],
-        offload: bool = True,
-        use_orig_params: bool = True,
-        limit_all_gathers: bool = True,
-        fsdp_kwargs: Optional[Dict[str, Any]] = None,
-        transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
+    model: torch.nn.Module,
+    device: Union[str, torch.device],
+    offload: bool = True,
+    use_orig_params: bool = True,
+    limit_all_gathers: bool = True,
+    fsdp_kwargs: Optional[Dict[str, Any]] = None,
+    transformer_layer_cls: Optional[Set[Type[torch.nn.Module]]] = None,
 ) -> FSDP:
     """
     Wrap a model with FSDP using common defaults and optional transformer auto-wrapping.
@@ -459,7 +457,7 @@
         "cpu_offload": CPUOffload(offload_params=offload) if offload else None,
         "use_orig_params": use_orig_params,
         "limit_all_gathers": limit_all_gathers,
-        "auto_wrap_policy": auto_wrap_policy
+        "auto_wrap_policy": auto_wrap_policy,
     }
 
     if fsdp_kwargs:
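
For reference, a minimal usage sketch of the wrap_with_fsdp helper touched above. It assumes the
patch is applied, a CUDA machine, and that the script is launched under a distributed launcher
such as torchrun; the default-config UNet2DConditionModel and the BasicTransformerBlock auto-wrap
target are illustrative stand-ins, not part of the patch.

    import torch
    import torch.distributed as dist

    from diffusers import UNet2DConditionModel
    from diffusers.models.attention import BasicTransformerBlock
    from diffusers.training_utils import wrap_with_fsdp

    # Launched with e.g. `torchrun --nproc_per_node=2 example.py`
    dist.init_process_group("nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()

    model = UNet2DConditionModel()  # illustrative default-config UNet; swap in a real checkpoint
    fsdp_model = wrap_with_fsdp(
        model,
        device=f"cuda:{local_rank}",
        offload=True,  # maps to CPUOffload(offload_params=True) per the diff above
        transformer_layer_cls={BasicTransformerBlock},  # auto-wrap each transformer block
    )

The transformer_layer_cls set feeds the transformer_auto_wrap_policy partial used in the helper,
so each listed block class becomes its own FSDP unit instead of wrapping the model as one flat unit.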