mirror of https://github.com/huggingface/diffusers.git

Commit: review comments resolved

src/diffusers/models/modeling_utils.py
@@ -708,9 +708,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
-            use_flashpack (`bool`, *optional*, defaults to `False`):
-                Whether to save the model in FlashPack format. FlashPack is a binary format that allows for faster
-                loading. Requires the `flashpack` library to be installed.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format.
+                FlashPack is a binary format that allows for faster loading.
+                Requires the `flashpack` library to be installed.
             kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
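For orientation, a minimal usage sketch of the option documented above; the model class, checkpoint id, and output directory are illustrative placeholders, not part of this diff:

```python
# Illustrative only: any ModelMixin subclass behaves the same way.
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet"
)

# Writes FlashPack binary weights next to the usual config files.
model.save_pretrained("./unet-flashpack", use_flashpack=True)
```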
@@ -741,42 +742,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if use_flashpack:
-            if not is_main_process:
-                return
-
-            try:
-                from flashpack import pack_to_file
-                import json  # Ensure json is imported
-            except ImportError:
-                raise ImportError("The `use_flashpack=True` argument requires the `flashpack` library.")
-
-            flashpack_weights_name = _add_variant("model.flashpack", variant)
-            save_path = os.path.join(save_directory, flashpack_weights_name)
-            # Define the config path - this is what your benchmark script is looking for
-            config_save_path = os.path.join(save_directory, "flashpack_config.json")
-
-            try:
-                target_dtype = getattr(self, "dtype", None)
-                logger.warning(f"Dtype used: {target_dtype}")
-                # 1. Save the binary weights
-                pack_to_file(self, save_path, target_dtype=target_dtype)
-
-                # 2. Save the metadata config
-                if hasattr(self, "config"):
-                    try:
-                        # Attempt to get dictionary representation
-                        if hasattr(self.config, "to_dict"):
-                            config_data = self.config.to_dict()
-                        else:
-                            config_data = dict(self.config)
-
-                        with open(config_save_path, "w") as f:
-                            json.dump(config_data, f, indent=4)
-                    except Exception as config_err:
-                        logger.warning(f"Weights saved but config serialization failed: {config_err}")
-
-                logger.info(f"Model weights saved in FlashPack format at {save_path}")
-            except Exception as e:
-                logger.error(f"Failed to save weights in FlashPack format: {e}")
-
+            from ..utils.flashpack_utils import save_flashpack
+
+            save_flashpack(
+                self,
+                save_directory,
+                variant=variant,
+                is_main_process=is_main_process,
+            )
+
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
             private = kwargs.pop("private", None)
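As a quick sanity check of the refactored save path, a sketch of what the target directory should contain afterwards (directory name carried over from the example above; `model.flashpack` is the unvarianted weights name produced by `save_flashpack`):

```python
import os

saved = set(os.listdir("./unet-flashpack"))
# save_flashpack writes the packed weights plus a best-effort metadata config.
assert "model.flashpack" in saved
assert "flashpack_config.json" in saved
```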
@@ -982,10 +953,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
             use_flashpack (`bool`, *optional*, defaults to `False`):
-                If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
+                If set to `True`, the model is first loaded from `flashpack`
+                (https://github.com/fal-ai/flashpack) weights if a compatible `.flashpack` file
                 is found. If flashpack is unavailable or the `.flashpack` file cannot be used, loading automatically falls back to
-                the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
-                flashpack`.
+                the standard loading path (for example, `safetensors`).
             disable_mmap (`bool`, *optional*, defaults to `False`):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
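A matching load-side sketch for the docstring above (path illustrative, carried over from the save example):

```python
from diffusers import UNet2DConditionModel

# Loads from model.flashpack when present; otherwise the standard
# safetensors path is used, per the fallback described above.
model = UNet2DConditionModel.from_pretrained("./unet-flashpack", use_flashpack=True)
```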
@@ -1252,65 +1222,43 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # If flashpack is not available or the file cannot be loaded, we fall back to
         # the standard loading path (e.g. safetensors or PyTorch).
         if use_flashpack:
-            try:
-                from flashpack import assign_from_file
-            except ImportError:
-                pass
-            else:
-                flashpack_weights_name = _add_variant("model.flashpack", variant)
-                try:
-                    flashpack_file = _get_model_file(
-                        pretrained_model_name_or_path,
-                        weights_name=flashpack_weights_name,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        token=token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                        commit_hash=commit_hash,
-                    )
-                except EnvironmentError:
-                    pass
-                else:
-                    dtype_orig = None
-                    if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
-                        if not isinstance(torch_dtype, torch.dtype):
-                            raise ValueError(
-                                f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
-                            )
-                        dtype_orig = cls._set_default_torch_dtype(torch_dtype)
-
-                    with no_init_weights():
-                        model = cls.from_config(config, **unused_kwargs)
-
-                    if dtype_orig is not None:
-                        torch.set_default_dtype(dtype_orig)
-
-                    try:
-                        assign_from_file(model, flashpack_file)
-                        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
-
-                        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
-                            model = model.to(torch_dtype)
-
-                        model.eval()
-
-                        if output_loading_info:
-                            loading_info = {
-                                "missing_keys": [],
-                                "unexpected_keys": [],
-                                "mismatched_keys": [],
-                                "error_msgs": [],
-                            }
-                            return model, loading_info
-
-                        return model
-                    except Exception:
-                        pass
+            weights_name = _add_variant("model.flashpack", variant)
+
+            try:
+                resolved_model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=weights_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                )
+            except EnvironmentError:
+                resolved_model_file = None
+
+            with no_init_weights():
+                model = cls.from_config(config, **unused_kwargs)
+
+            if resolved_model_file is not None:
+                from ..utils.flashpack_utils import load_flashpack
+
+                load_flashpack(model, resolved_model_file)
+                model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+                model.eval()
+
+                if output_loading_info:
+                    loading_info = {
+                        "missing_keys": [],
+                        "unexpected_keys": [],
+                        "mismatched_keys": [],
+                        "error_msgs": [],
+                    }
+                    return model, loading_info
+
+                return model

         # in the case it is sharded, we have already the index
         if is_sharded:
             resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
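The fallback contract implemented above, restated as a sketch: requesting FlashPack on a checkpoint without a `.flashpack` file must not raise; the `EnvironmentError` from `_get_model_file` is swallowed, `resolved_model_file` stays `None`, and the standard loading path takes over (checkpoint id illustrative):

```python
from diffusers import UNet2DConditionModel

# No model.flashpack in this repo: loading silently falls back to safetensors.
model = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet", use_flashpack=True
)
```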
src/diffusers/utils/flashpack_utils.py (new file, 83 lines)
@@ -0,0 +1,83 @@
import json
import os
from typing import Optional

from .import_utils import is_flashpack_available
from .logging import get_logger
from ..utils import _add_variant


logger = get_logger(__name__)


def save_flashpack(
    model,
    save_directory: str,
    variant: Optional[str] = None,
    is_main_process: bool = True,
):
    """
    Save model weights in FlashPack format along with a metadata config.

    Args:
        model: Diffusers model instance.
        save_directory (`str`): Directory to save the weights to.
        variant (`str`, *optional*): Model variant.
        is_main_process (`bool`, *optional*, defaults to `True`):
            Whether the calling process is the main one; in distributed settings, only the main process
            should write to disk.
    """
    if not is_flashpack_available():
        raise ImportError(
            "The `use_flashpack=True` argument requires the `flashpack` package. "
            "Install it with `pip install flashpack`."
        )

    from flashpack import pack_to_file

    if not is_main_process:
        return

    os.makedirs(save_directory, exist_ok=True)

    weights_name = _add_variant("model.flashpack", variant)
    weights_path = os.path.join(save_directory, weights_name)
    config_path = os.path.join(save_directory, "flashpack_config.json")

    try:
        target_dtype = getattr(model, "dtype", None)
        logger.warning(f"Dtype used for FlashPack save: {target_dtype}")

        # 1. Save binary weights
        pack_to_file(model, weights_path, target_dtype=target_dtype)

        # 2. Save config metadata (best-effort)
        if hasattr(model, "config"):
            try:
                if hasattr(model.config, "to_dict"):
                    config_data = model.config.to_dict()
                else:
                    config_data = dict(model.config)

                with open(config_path, "w") as f:
                    json.dump(config_data, f, indent=4)

            except Exception as config_err:
                logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}")

    except Exception as e:
        logger.error(f"Failed to save weights in FlashPack format: {e}")
        raise


def load_flashpack(model, flashpack_file: str):
    """
    Assign FlashPack weights from a file into an initialized PyTorch model.
    """
    if not is_flashpack_available():
        raise ImportError(
            "FlashPack weights require the `flashpack` package. Install with `pip install flashpack`."
        )

    from flashpack import assign_from_file

    logger.warning(f"Loading FlashPack weights from {flashpack_file}")

    try:
        assign_from_file(model, flashpack_file)
    except Exception as e:
        raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e
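A short sketch of calling the new helpers directly; they are normally reached through `save_pretrained` / `from_pretrained`, and the checkpoint paths here are illustrative:

```python
from diffusers import UNet2DConditionModel
from diffusers.utils.flashpack_utils import save_flashpack, load_flashpack

model = UNet2DConditionModel.from_pretrained("./unet-flashpack", use_flashpack=True)

# Round-trip through the helpers directly.
save_flashpack(model, "./unet-flashpack-copy")  # writes model.flashpack + flashpack_config.json
load_flashpack(model, "./unet-flashpack-copy/model.flashpack")  # assigns weights in place
```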
src/diffusers/utils/import_utils.py
@@ -231,7 +231,7 @@ _aiter_available, _aiter_version = _is_package_available("aiter")
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
 _av_available, _av_version = _is_package_available("av")
+_flashpack_available, _flashpack_version = _is_package_available("flashpack")


 def is_torch_available():
     return _torch_available
@@ -424,6 +424,8 @@ def is_kornia_available():
 def is_av_available():
     return _av_available


+def is_flashpack_available():
+    return _flashpack_available


 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
@@ -941,6 +943,14 @@ def is_aiter_version(operation: str, version: str):
         return False
     return compare_versions(parse(_aiter_version), operation, version)


+@cache
+def is_flashpack_version(operation: str, version: str):
+    """
+    Compares the current flashpack version to a given reference with an operation.
+    """
+    if not _flashpack_available:
+        return False
+    return compare_versions(parse(_flashpack_version), operation, version)


 def get_objects_from_module(module):
     """
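These helpers follow the existing `is_*_available` / `is_*_version` pattern in `import_utils`; a usage sketch (the `0.1.0` threshold is an arbitrary illustration, not a version this diff requires):

```python
from diffusers.utils.import_utils import is_flashpack_available, is_flashpack_version

if is_flashpack_available() and is_flashpack_version(">=", "0.1.0"):
    # Safe to take the FlashPack fast path.
    ...
```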