From e5bb10cfe10dce1e806e442993f4171cc51ad426 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E2=80=9Cdevanshi00=E2=80=9D?= <“devanshi7309@gmail.com”>
Date: Wed, 21 Jan 2026 04:22:50 +0530
Subject: [PATCH] review comments resolved

---
 src/diffusers/models/modeling_utils.py | 144 ++++++++-----------------
 src/diffusers/utils/flashpack_utils.py |  83 ++++++++++++++
 src/diffusers/utils/import_utils.py    |  12 ++-
 3 files changed, 140 insertions(+), 99 deletions(-)
 create mode 100644 src/diffusers/utils/flashpack_utils.py

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index a4bfa47f6d..62fbc41cd5 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -708,9 +708,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
-            use_flashpack (`bool`, *optional*, defaults to `False`):
-                Whether to save the model in FlashPack format. FlashPack is a binary format that allows for faster
-                loading. Requires the `flashpack` library to be installed.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format.
+                FlashPack is a binary format that allows for faster loading.
+                Requires the `flashpack` library to be installed.
             kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
@@ -741,42 +742,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if use_flashpack:
             if not is_main_process:
                 return
-
-            try:
-                from flashpack import pack_to_file
-                import json  # Ensure json is imported
-            except ImportError:
-                raise ImportError("The `use_flashpack=True` argument requires the `flashpack` library.")
-
-            flashpack_weights_name = _add_variant("model.flashpack", variant)
-            save_path = os.path.join(save_directory, flashpack_weights_name)
-            # Define the config path - this is what your benchmark script is looking for
-            config_save_path = os.path.join(save_directory, "flashpack_config.json")
-
-            try:
-                target_dtype = getattr(self, "dtype", None)
-                logger.warning(f"Dtype used: {target_dtype}")
-                # 1. Save the binary weights
-                pack_to_file(self, save_path, target_dtype=target_dtype)
-
-                # 2. Save the metadata config
-                if hasattr(self, "config"):
-                    try:
-                        # Attempt to get dictionary representation
-                        if hasattr(self.config, "to_dict"):
-                            config_data = self.config.to_dict()
-                        else:
-                            config_data = dict(self.config)
-
-                        with open(config_save_path, "w") as f:
-                            json.dump(config_data, f, indent=4)
-                    except Exception as config_err:
-                        logger.warning(f"Weights saved but config serialization failed: {config_err}")
-
-                logger.info(f"Model weights saved in FlashPack format at {save_path}")
-            except Exception as e:
-                logger.error(f"Failed to save weights in FlashPack format: {e}")
-
+            from ..utils.flashpack_utils import save_flashpack
+            save_flashpack(
+                self,
+                save_directory,
+                variant=variant,
+            )
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
             private = kwargs.pop("private", None)
@@ -982,10 +953,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
             use_flashpack (`bool`, *optional*, defaults to `False`):
-                If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
+                If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack) weights if a compatible `.flashpack` file
                 is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to
-                the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
-                flashpack`.
+                the standard loading path (for example, `safetensors`).
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the
                 model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -1252,65 +1222,43 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # If flashpack is not available or the file cannot be loaded, we fall back to
         # the standard loading path (e.g. safetensors or PyTorch).
         if use_flashpack:
+            weights_name = _add_variant("model.flashpack", variant)
+
             try:
-                from flashpack import assign_from_file
-            except ImportError:
-                pass
-            else:
-                flashpack_weights_name = _add_variant("model.flashpack", variant)
-                try:
-                    flashpack_file = _get_model_file(
-                        pretrained_model_name_or_path,
-                        weights_name=flashpack_weights_name,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        token=token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                        commit_hash=commit_hash,
-                    )
-                except EnvironmentError:
-                    pass
-                else:
-                    dtype_orig = None
-                    if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
-                        if not isinstance(torch_dtype, torch.dtype):
-                            raise ValueError(
-                                f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
-                            )
-                        dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+                resolved_model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=weights_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                )
+            except EnvironmentError:
+                resolved_model_file = None
+            with no_init_weights():
+                model = cls.from_config(config, **unused_kwargs)
+            if resolved_model_file is not None:
+                from ..utils.flashpack_utils import load_flashpack
+                load_flashpack(model, resolved_model_file)
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+            model.eval()
 
-                    with no_init_weights():
-                        model = cls.from_config(config, **unused_kwargs)
+            if output_loading_info:
+                loading_info = {
+                    "missing_keys": [],
+                    "unexpected_keys": [],
+                    "mismatched_keys": [],
+                    "error_msgs": [],
+                }
+                return model, loading_info
 
-                    if dtype_orig is not None:
-                        torch.set_default_dtype(dtype_orig)
-
-                    try:
-                        assign_from_file(model, flashpack_file)
-                        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
-
-                        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
-                            model = model.to(torch_dtype)
-
-                        model.eval()
-
-                        if output_loading_info:
-                            loading_info = {
-                                "missing_keys": [],
-                                "unexpected_keys": [],
-                                "mismatched_keys": [],
-                                "error_msgs": [],
-                            }
-                            return model, loading_info
-
-                        return model
-
-                    except Exception:
-                        pass
+            return model
+
         # in the case it is sharded, we have already the index
         if is_sharded:
             resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
diff --git a/src/diffusers/utils/flashpack_utils.py b/src/diffusers/utils/flashpack_utils.py
new file mode 100644
index 0000000000..14031a7c54
--- /dev/null
+++ b/src/diffusers/utils/flashpack_utils.py
@@ -0,0 +1,83 @@
+import json
+import os
+from typing import Optional
+from .import_utils import is_flashpack_available
+from .logging import get_logger
+from ..utils import _add_variant
+
+logger = get_logger(__name__)
+
+def save_flashpack(
+    model,
+    save_directory: str,
+    variant: Optional[str] = None,
+    is_main_process: bool = True,
+):
+    """
+    Save model weights in FlashPack format along with a metadata config.
+
+    Args:
+        model: Diffusers model instance
+        save_directory (`str`): Directory to save weights
+        variant (`str`, *optional*): Model variant
+    """
+    if not is_flashpack_available():
+        raise ImportError(
+            "The `use_flashpack=True` argument requires the `flashpack` package. "
+            "Install it with `pip install flashpack`."
+        )
+
+    from flashpack import pack_to_file
+
+    os.makedirs(save_directory, exist_ok=True)
+
+    weights_name = _add_variant("model.flashpack", variant)
+    weights_path = os.path.join(save_directory, weights_name)
+    config_path = os.path.join(save_directory, "flashpack_config.json")
+
+    try:
+        target_dtype = getattr(model, "dtype", None)
+        logger.warning(f"Dtype used for FlashPack save: {target_dtype}")
+
+        # 1. Save binary weights
+        pack_to_file(model, weights_path, target_dtype=target_dtype)
+
+        # 2. Save config metadata (best-effort)
+        if hasattr(model, "config"):
+            try:
+                if hasattr(model.config, "to_dict"):
+                    config_data = model.config.to_dict()
+                else:
+                    config_data = dict(model.config)
+
+                with open(config_path, "w") as f:
+                    json.dump(config_data, f, indent=4)
+
+            except Exception as config_err:
+                logger.warning(
+                    f"FlashPack weights saved, but config serialization failed: {config_err}"
+                )
+
+    except Exception as e:
+        logger.error(f"Failed to save weights in FlashPack format: {e}")
+        raise
+
+def load_flashpack(model, flashpack_file: str):
+    """
+    Assign FlashPack weights from a file into an initialized PyTorch model.
+    """
+    if not is_flashpack_available():
+        raise ImportError(
+            "FlashPack weights require the `flashpack` package. "
+            "Install with `pip install flashpack`."
+        )
+
+    from flashpack import assign_from_file
+    logger.warning(f"Loading FlashPack weights from {flashpack_file}")
+
+    try:
+        assign_from_file(model, flashpack_file)
+    except Exception as e:
+        raise RuntimeError(
+            f"Failed to load FlashPack weights from {flashpack_file}"
+        ) from e
\ No newline at end of file
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 425c360a31..2b99e42a26 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -231,7 +231,7 @@
 _aiter_available, _aiter_version = _is_package_available("aiter")
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
 _av_available, _av_version = _is_package_available("av")
-
+_flashpack_available, _flashpack_version = _is_package_available("flashpack")
 
 def is_torch_available():
     return _torch_available
@@ -424,6 +424,8 @@ def is_kornia_available():
 def is_av_available():
     return _av_available
 
+def is_flashpack_available():
+    return _flashpack_available
 
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
@@ -941,6 +943,14 @@ def is_aiter_version(operation: str, version: str):
         return False
     return compare_versions(parse(_aiter_version), operation, version)
 
+@cache
+def is_flashpack_version(operation: str, version: str):
+    """
+    Compares the current flashpack version to a given reference with an operation.
+    """
+    if not _flashpack_available:
+        return False
+    return compare_versions(parse(_flashpack_version), operation, version)
 
 def get_objects_from_module(module):
     """
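
Usage sketch (illustrative, not part of the patch): the snippet below shows how the `use_flashpack` flags wired up above could be exercised end to end. The model class, checkpoint ID, and output directory are placeholders, and it assumes the surrounding `save_pretrained` / `from_pretrained` logic still writes and reads the usual `config.json` next to the FlashPack file; the `flashpack` package must be installed (`pip install flashpack`).

    import torch

    # Placeholder model class; any diffusers ModelMixin subclass exposes the same flags.
    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
    )

    # Routes through save_flashpack(), which writes model.flashpack plus
    # flashpack_config.json into the target directory.
    unet.save_pretrained("./unet-flashpack", use_flashpack=True)

    # Per the updated docstring, a compatible model.flashpack file is used when found;
    # otherwise loading falls back to the standard safetensors/PyTorch path.
    unet = UNet2DConditionModel.from_pretrained("./unet-flashpack", use_flashpack=True)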