mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
sayakpaul
2026-01-22 16:44:06 +05:30
parent b603429ff5
commit 55eaa6efb2
4 changed files with 26 additions and 23 deletions

View File

@@ -709,9 +709,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a
binary format that allows for faster loading. Requires the `flashpack` library to be installed.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -743,6 +742,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if not is_main_process:
return
from ..utils.flashpack_utils import save_flashpack
save_flashpack(
self,
save_directory,
@@ -953,9 +953,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack`
file cannot be used, loading falls back automatically to the standard path (for example, `safetensors`).
disable_mmap (`bool`, *optional*, defaults to `False`):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
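A hedged loading sketch matching the description above (not part of this commit); the checkpoint path is illustrative:

from diffusers import UNet2DConditionModel

# Tries a compatible `.flashpack` file first; if flashpack is not installed or no
# usable file is found, loading falls back to the standard path (e.g. safetensors).
model = UNet2DConditionModel.from_pretrained("./unet-flashpack", use_flashpack=True)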
@@ -1279,7 +1279,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)
if resolved_model_file is None and not is_sharded:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
@@ -1323,12 +1323,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if flashpack_file is not None:
from ..utils.flashpack_utils import load_flashpack
# Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
# the model with meta tensors. Since FlashPack cannot write into meta tensors, we
# explicitly materialize parameters before loading to ensure correctness and parity
# with the standard loading path.
if any(p.device.type == "meta" for p in model.parameters()):
model.to_empty(device="cpu")
load_flashpack(model, flashpack_file)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
model.eval()
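The meta-tensor handling described in the comment above can be illustrated in isolation; this is a minimal PyTorch sketch, independent of FlashPack and of diffusers:

import torch

# Parameters created on the meta device carry shapes and dtypes but no storage,
# so a loader cannot write values into them until they are materialized.
with torch.device("meta"):
    layer = torch.nn.Linear(4, 4)

assert all(p.device.type == "meta" for p in layer.parameters())

# `to_empty` allocates real (uninitialized) storage on the target device, after
# which a state dict (or FlashPack assignment) can fill in the actual values.
layer.to_empty(device="cpu")
layer.load_state_dict(torch.nn.Linear(4, 4).state_dict())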
@@ -1434,12 +1435,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if output_loading_info:
return model, loading_info
logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully")
return model
# Adapted from `transformers`.
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):

View File

@@ -270,8 +270,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
install flashpack`.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""

View File

@@ -1,12 +1,15 @@
import json
import os
from typing import Optional
from ..utils import _add_variant
from .import_utils import is_flashpack_available
from .logging import get_logger
logger = get_logger(__name__)
def save_flashpack(
model,
save_directory: str,
@@ -54,30 +57,25 @@ def save_flashpack(
json.dump(config_data, f, indent=4)
except Exception as config_err:
logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}")
except Exception as e:
logger.error(f"Failed to save weights in FlashPack format: {e}")
raise
def load_flashpack(model, flashpack_file: str):
"""
Assign FlashPack weights from a file into an initialized PyTorch model.
"""
if not is_flashpack_available():
raise ImportError("FlashPack weights require the `flashpack` package. Install with `pip install flashpack`.")
from flashpack import assign_from_file
logger.warning(f"Loading FlashPack weights from {flashpack_file}")
try:
assign_from_file(model, flashpack_file)
except Exception as e:
raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e

View File

@@ -233,6 +233,7 @@ _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("mo
_av_available, _av_version = _is_package_available("av")
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
def is_torch_available():
return _torch_available
@@ -424,9 +425,11 @@ def is_kornia_available():
def is_av_available():
return _av_available
def is_flashpack_available():
return _flashpack_available
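A minimal sketch of what a check like `_is_package_available("flashpack")` typically boils down to; this is a generic importlib-based reimplementation for illustration, not the exact diffusers helper:

import importlib.metadata
import importlib.util


def package_available(name: str) -> tuple[bool, str]:
    # A package counts as available if it is importable and has installed metadata.
    if importlib.util.find_spec(name) is None:
        return False, "N/A"
    try:
        return True, importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        return False, "N/A"


print(package_available("flashpack"))  # (True, "<version>") when installed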
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -943,6 +946,7 @@ def is_aiter_version(operation: str, version: str):
return False
return compare_versions(parse(_aiter_version), operation, version)
@cache
def is_flashpack_version(operation: str, version: str):
"""
@@ -952,6 +956,7 @@ def is_flashpack_version(operation: str, version: str):
return False
return compare_versions(parse(_flashpack_version), operation, version)
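A hedged usage sketch for the version gate above (not part of this commit); the minimum version string is illustrative:

from diffusers.utils.import_utils import is_flashpack_available, is_flashpack_version

if is_flashpack_available() and is_flashpack_version(">=", "0.1.0"):
    ...  # rely on FlashPack features known to exist from that release onward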
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects