From 55eaa6efb28c4cee90643b95435b43e432250058 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 22 Jan 2026 16:44:06 +0530 Subject: [PATCH] style --- src/diffusers/models/modeling_utils.py | 20 ++++++++++---------- src/diffusers/pipelines/pipeline_utils.py | 4 ++-- src/diffusers/utils/flashpack_utils.py | 20 +++++++++----------- src/diffusers/utils/import_utils.py | 5 +++++ 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 5b0f8a3a0d..c89204dff7 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -709,9 +709,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). use_flashpack (`bool`, *optional*, defaults to `False`): - Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. - FlashPack is a binary format that allows for faster loading. - Requires the `flashpack` library to be installed. + Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a + binary format that allows for faster loading. Requires the `flashpack` library to be installed. kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ @@ -743,6 +742,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): if not is_main_process: return from ..utils.flashpack_utils import save_flashpack + save_flashpack( self, save_directory, @@ -953,9 +953,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. If set to `False`, `safetensors` weights are not loaded. use_flashpack (`bool`, *optional*, defaults to `False`): - If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack) weights if a compatible `.flashpack` file - is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to - the standard loading path (for example, `safetensors`). + If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack) + weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack` + file cannot be used, automatic fallback to the standard loading path (for example, `safetensors`). disable_mmap ('bool', *optional*, defaults to 'False'): Whether to disable mmap when loading a Safetensors model. This option can perform better when the model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. @@ -1279,7 +1279,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): logger.warning( "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead." ) - + if resolved_model_file is None and not is_sharded: resolved_model_file = _get_model_file( pretrained_model_name_or_path, @@ -1323,12 +1323,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): if flashpack_file is not None: from ..utils.flashpack_utils import load_flashpack + # Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing # the model with meta tensors. Since FlashPack cannot write into meta tensors, we # explicitly materialize parameters before loading to ensure correctness and parity # with the standard loading path. if any(p.device.type == "meta" for p in model.parameters()): - model.to_empty(device="cpu") + model.to_empty(device="cpu") load_flashpack(model, flashpack_file) model.register_to_config(_name_or_path=pretrained_model_name_or_path) model.eval() @@ -1434,12 +1435,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): if output_loading_info: return model, loading_info - + logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully") return model - # Adapted from `transformers`. @wraps(torch.nn.Module.cuda) def cuda(self, *args, **kwargs): diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 34e42f4286..2152221606 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -270,8 +270,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). use_flashpack (`bool`, *optional*, defaults to `False`): - Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip install - flashpack`. + Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip + install flashpack`. kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ diff --git a/src/diffusers/utils/flashpack_utils.py b/src/diffusers/utils/flashpack_utils.py index 14031a7c54..821fe5e7fd 100644 --- a/src/diffusers/utils/flashpack_utils.py +++ b/src/diffusers/utils/flashpack_utils.py @@ -1,12 +1,15 @@ import json import os from typing import Optional + +from ..utils import _add_variant from .import_utils import is_flashpack_available from .logging import get_logger -from ..utils import _add_variant + logger = get_logger(__name__) + def save_flashpack( model, save_directory: str, @@ -54,30 +57,25 @@ def save_flashpack( json.dump(config_data, f, indent=4) except Exception as config_err: - logger.warning( - f"FlashPack weights saved, but config serialization failed: {config_err}" - ) + logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}") except Exception as e: logger.error(f"Failed to save weights in FlashPack format: {e}") raise + def load_flashpack(model, flashpack_file: str): """ Assign FlashPack weights from a file into an initialized PyTorch model. """ if not is_flashpack_available(): - raise ImportError( - "FlashPack weights require the `flashpack` package. " - "Install with `pip install flashpack`." - ) + raise ImportError("FlashPack weights require the `flashpack` package. Install with `pip install flashpack`.") from flashpack import assign_from_file + logger.warning(f"Loading FlashPack weights from {flashpack_file}") try: assign_from_file(model, flashpack_file) except Exception as e: - raise RuntimeError( - f"Failed to load FlashPack weights from {flashpack_file}" - ) from e \ No newline at end of file + raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 2b99e42a26..af6df925d7 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -233,6 +233,7 @@ _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("mo _av_available, _av_version = _is_package_available("av") _flashpack_available, _flashpack_version = _is_package_available("flashpack") + def is_torch_available(): return _torch_available @@ -424,9 +425,11 @@ def is_kornia_available(): def is_av_available(): return _av_available + def is_flashpack_available(): return _flashpack_available + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -943,6 +946,7 @@ def is_aiter_version(operation: str, version: str): return False return compare_versions(parse(_aiter_version), operation, version) + @cache def is_flashpack_version(operation: str, version: str): """ @@ -952,6 +956,7 @@ def is_flashpack_version(operation: str, version: str): return False return compare_versions(parse(_flashpack_version), operation, version) + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects