mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
style
This commit is contained in:
@@ -709,9 +709,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format.
|
||||
FlashPack is a binary format that allows for faster loading.
|
||||
Requires the `flashpack` library to be installed.
|
||||
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a
|
||||
binary format that allows for faster loading. Requires the `flashpack` library to be installed.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
@@ -743,6 +742,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
if not is_main_process:
|
||||
return
|
||||
from ..utils.flashpack_utils import save_flashpack
|
||||
|
||||
save_flashpack(
|
||||
self,
|
||||
save_directory,
|
||||
@@ -953,9 +953,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
||||
weights. If set to `False`, `safetensors` weights are not loaded.
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack) weights if a compatible `.flashpack` file
|
||||
is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to
|
||||
the standard loading path (for example, `safetensors`).
|
||||
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
|
||||
weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack`
|
||||
file cannot be used, automatic fallback to the standard loading path (for example, `safetensors`).
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
@@ -1279,7 +1279,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
logger.warning(
|
||||
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
||||
)
|
||||
|
||||
|
||||
if resolved_model_file is None and not is_sharded:
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
@@ -1323,12 +1323,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
if flashpack_file is not None:
|
||||
from ..utils.flashpack_utils import load_flashpack
|
||||
|
||||
# Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
|
||||
# the model with meta tensors. Since FlashPack cannot write into meta tensors, we
|
||||
# explicitly materialize parameters before loading to ensure correctness and parity
|
||||
# with the standard loading path.
|
||||
if any(p.device.type == "meta" for p in model.parameters()):
|
||||
model.to_empty(device="cpu")
|
||||
model.to_empty(device="cpu")
|
||||
load_flashpack(model, flashpack_file)
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
model.eval()
|
||||
@@ -1434,12 +1435,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
if output_loading_info:
|
||||
return model, loading_info
|
||||
|
||||
|
||||
logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
# Adapted from `transformers`.
|
||||
@wraps(torch.nn.Module.cuda)
|
||||
def cuda(self, *args, **kwargs):
|
||||
|
||||
@@ -270,8 +270,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip install
|
||||
flashpack`.
|
||||
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
|
||||
install flashpack`.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from ..utils import _add_variant
|
||||
from .import_utils import is_flashpack_available
|
||||
from .logging import get_logger
|
||||
from ..utils import _add_variant
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_flashpack(
|
||||
model,
|
||||
save_directory: str,
|
||||
@@ -54,30 +57,25 @@ def save_flashpack(
|
||||
json.dump(config_data, f, indent=4)
|
||||
|
||||
except Exception as config_err:
|
||||
logger.warning(
|
||||
f"FlashPack weights saved, but config serialization failed: {config_err}"
|
||||
)
|
||||
logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save weights in FlashPack format: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def load_flashpack(model, flashpack_file: str):
|
||||
"""
|
||||
Assign FlashPack weights from a file into an initialized PyTorch model.
|
||||
"""
|
||||
if not is_flashpack_available():
|
||||
raise ImportError(
|
||||
"FlashPack weights require the `flashpack` package. "
|
||||
"Install with `pip install flashpack`."
|
||||
)
|
||||
raise ImportError("FlashPack weights require the `flashpack` package. Install with `pip install flashpack`.")
|
||||
|
||||
from flashpack import assign_from_file
|
||||
|
||||
logger.warning(f"Loading FlashPack weights from {flashpack_file}")
|
||||
|
||||
try:
|
||||
assign_from_file(model, flashpack_file)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to load FlashPack weights from {flashpack_file}"
|
||||
) from e
|
||||
raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e
|
||||
|
||||
@@ -233,6 +233,7 @@ _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("mo
|
||||
_av_available, _av_version = _is_package_available("av")
|
||||
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
@@ -424,9 +425,11 @@ def is_kornia_available():
|
||||
def is_av_available():
|
||||
return _av_available
|
||||
|
||||
|
||||
def is_flashpack_available():
|
||||
return _flashpack_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
@@ -943,6 +946,7 @@ def is_aiter_version(operation: str, version: str):
|
||||
return False
|
||||
return compare_versions(parse(_aiter_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_flashpack_version(operation: str, version: str):
|
||||
"""
|
||||
@@ -952,6 +956,7 @@ def is_flashpack_version(operation: str, version: str):
|
||||
return False
|
||||
return compare_versions(parse(_flashpack_version), operation, version)
|
||||
|
||||
|
||||
def get_objects_from_module(module):
|
||||
"""
|
||||
Returns a dict of object names and values in a module, while skipping private/internal objects
|
||||
|
||||
Reference in New Issue
Block a user