mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

review comments resolved

This commit is contained in:
devanshi00
2026-01-21 04:22:50 +05:30
parent ec541906c5
commit e5bb10cfe1
3 changed files with 140 additions and 99 deletions


@@ -708,9 +708,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
-            use_flashpack (`bool`, *optional*, defaults to `False`):
-                Whether to save the model in FlashPack format. FlashPack is a binary format that allows for faster
-                loading. Requires the `flashpack` library to be installed.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format.
+                FlashPack is a binary format that allows for faster loading.
+                Requires the `flashpack` library to be installed.
             kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
@@ -741,42 +742,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if use_flashpack:
             if not is_main_process:
                 return
-            try:
-                from flashpack import pack_to_file
-                import json  # Ensure json is imported
-            except ImportError:
-                raise ImportError("The `use_flashpack=True` argument requires the `flashpack` library.")
-
-            flashpack_weights_name = _add_variant("model.flashpack", variant)
-            save_path = os.path.join(save_directory, flashpack_weights_name)
-
-            # Define the config path - this is what your benchmark script is looking for
-            config_save_path = os.path.join(save_directory, "flashpack_config.json")
-
-            try:
-                target_dtype = getattr(self, "dtype", None)
-                logger.warning(f"Dtype used: {target_dtype}")
-
-                # 1. Save the binary weights
-                pack_to_file(self, save_path, target_dtype=target_dtype)
-
-                # 2. Save the metadata config
-                if hasattr(self, "config"):
-                    try:
-                        # Attempt to get dictionary representation
-                        if hasattr(self.config, "to_dict"):
-                            config_data = self.config.to_dict()
-                        else:
-                            config_data = dict(self.config)
-
-                        with open(config_save_path, "w") as f:
-                            json.dump(config_data, f, indent=4)
-                    except Exception as config_err:
-                        logger.warning(f"Weights saved but config serialization failed: {config_err}")
-
-                logger.info(f"Model weights saved in FlashPack format at {save_path}")
-            except Exception as e:
-                logger.error(f"Failed to save weights in FlashPack format: {e}")
+            from ..utils.flashpack_utils import save_flashpack
+            save_flashpack(
+                self,
+                save_directory,
+                variant=variant,
+            )

         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
             private = kwargs.pop("private", None)
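
The refactor keeps the on-disk contract of the old inline code. Assuming `model` is any loaded `ModelMixin` instance and the helper behaves as shown in the new flashpack_utils module below, a sketch of what a FlashPack save produces, with `_add_variant` folding a variant into the filename:

import os

model.save_pretrained("./unet-flashpack", use_flashpack=True)
assert "model.flashpack" in os.listdir("./unet-flashpack")
assert "flashpack_config.json" in os.listdir("./unet-flashpack")

# With a variant, the weights name becomes e.g. model.fp16.flashpack:
model.save_pretrained("./unet-flashpack-fp16", use_flashpack=True, variant="fp16")
assert "model.fp16.flashpack" in os.listdir("./unet-flashpack-fp16")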
@@ -982,10 +953,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
             use_flashpack (`bool`, *optional*, defaults to `False`):
-                If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
+                If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack) weights if a compatible `.flashpack` file
                 is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading automatically falls back to
-                the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
-                flashpack`.
+                the standard loading path (for example, `safetensors`).
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                 is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
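
A matching load-side sketch of the documented behavior, reusing the local directory from the save sketch above:

from diffusers import UNet2DConditionModel

# Uses model.flashpack when both the file and the `flashpack` package are
# available; otherwise the standard safetensors/PyTorch path is used.
model = UNet2DConditionModel.from_pretrained("./unet-flashpack", use_flashpack=True)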
@@ -1252,65 +1222,43 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # If flashpack is not available or the file cannot be loaded, we fall back to
         # the standard loading path (e.g. safetensors or PyTorch).
         if use_flashpack:
+            weights_name = _add_variant("model.flashpack", variant)
             try:
-                from flashpack import assign_from_file
-            except ImportError:
-                pass
-            else:
-                flashpack_weights_name = _add_variant("model.flashpack", variant)
-                try:
-                    flashpack_file = _get_model_file(
-                        pretrained_model_name_or_path,
-                        weights_name=flashpack_weights_name,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        token=token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                        commit_hash=commit_hash,
-                    )
-                except EnvironmentError:
-                    pass
-                else:
-                    dtype_orig = None
-                    if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
-                        if not isinstance(torch_dtype, torch.dtype):
-                            raise ValueError(
-                                f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
-                            )
-                        dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+                resolved_model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=weights_name,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                )
+            except EnvironmentError:
+                resolved_model_file = None
+
+            with no_init_weights():
+                model = cls.from_config(config, **unused_kwargs)
+
+            if resolved_model_file is not None:
+                from ..utils.flashpack_utils import load_flashpack
+
+                load_flashpack(model, resolved_model_file)
+
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+            model.eval()
-
-                    with no_init_weights():
-                        model = cls.from_config(config, **unused_kwargs)
-
+            if output_loading_info:
+                loading_info = {
+                    "missing_keys": [],
+                    "unexpected_keys": [],
+                    "mismatched_keys": [],
+                    "error_msgs": [],
+                }
+                return model, loading_info
-                    if dtype_orig is not None:
-                        torch.set_default_dtype(dtype_orig)
-
-                    try:
-                        assign_from_file(model, flashpack_file)
-                        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
-                        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
-                            model = model.to(torch_dtype)
-                        model.eval()
-
-                        if output_loading_info:
-                            loading_info = {
-                                "missing_keys": [],
-                                "unexpected_keys": [],
-                                "mismatched_keys": [],
-                                "error_msgs": [],
-                            }
-                            return model, loading_info
-
-                        return model
-                    except Exception:
-                        pass
+
+            return model

         # in the case it is sharded, we have already the index
         if is_sharded:
             resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
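
One observable consequence of this code path: a FlashPack load always reports empty `loading_info`, since weights are assigned wholesale rather than matched key by key. A sketch, continuing from the load example above:

model, info = UNet2DConditionModel.from_pretrained(
    "./unet-flashpack", use_flashpack=True, output_loading_info=True
)
print(info)  # {'missing_keys': [], 'unexpected_keys': [], 'mismatched_keys': [], 'error_msgs': []}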


@@ -0,0 +1,83 @@
+import json
+import os
+from typing import Optional
+
+from .import_utils import is_flashpack_available
+from .logging import get_logger
+from ..utils import _add_variant
+
+logger = get_logger(__name__)
+
+
+def save_flashpack(
+    model,
+    save_directory: str,
+    variant: Optional[str] = None,
+    is_main_process: bool = True,
+):
+    """
+    Save model weights in FlashPack format along with a metadata config.
+
+    Args:
+        model: Diffusers model instance
+        save_directory (`str`): Directory to save weights
+        variant (`str`, *optional*): Model variant
+    """
+    if not is_flashpack_available():
+        raise ImportError(
+            "The `use_flashpack=True` argument requires the `flashpack` package. "
+            "Install it with `pip install flashpack`."
+        )
+
+    from flashpack import pack_to_file
+
+    os.makedirs(save_directory, exist_ok=True)
+    weights_name = _add_variant("model.flashpack", variant)
+    weights_path = os.path.join(save_directory, weights_name)
+    config_path = os.path.join(save_directory, "flashpack_config.json")
+
+    try:
+        target_dtype = getattr(model, "dtype", None)
+        logger.warning(f"Dtype used for FlashPack save: {target_dtype}")
+
+        # 1. Save binary weights
+        pack_to_file(model, weights_path, target_dtype=target_dtype)
+
+        # 2. Save config metadata (best-effort)
+        if hasattr(model, "config"):
+            try:
+                if hasattr(model.config, "to_dict"):
+                    config_data = model.config.to_dict()
+                else:
+                    config_data = dict(model.config)
+
+                with open(config_path, "w") as f:
+                    json.dump(config_data, f, indent=4)
+            except Exception as config_err:
+                logger.warning(
+                    f"FlashPack weights saved, but config serialization failed: {config_err}"
+                )
+    except Exception as e:
+        logger.error(f"Failed to save weights in FlashPack format: {e}")
+        raise
+
+
+def load_flashpack(model, flashpack_file: str):
+    """
+    Assign FlashPack weights from a file into an initialized PyTorch model.
+    """
+    if not is_flashpack_available():
+        raise ImportError(
+            "FlashPack weights require the `flashpack` package. "
+            "Install with `pip install flashpack`."
+        )
+
+    from flashpack import assign_from_file
+
+    logger.warning(f"Loading FlashPack weights from {flashpack_file}")
+    try:
+        assign_from_file(model, flashpack_file)
+    except Exception as e:
+        raise RuntimeError(
+            f"Failed to load FlashPack weights from {flashpack_file}"
+        ) from e
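
The helpers can also be called directly. A sketch, assuming the module lands at `diffusers.utils.flashpack_utils` (inferred from the `ModelMixin` import above) and that `model` is an initialized Diffusers model:

from diffusers.utils.flashpack_utils import load_flashpack, save_flashpack

save_flashpack(model, "./packed")

# Loading assigns into an already-constructed model of the same architecture.
fresh = type(model).from_config(model.config)
load_flashpack(fresh, "./packed/model.flashpack")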


@@ -231,7 +231,8 @@ _aiter_available, _aiter_version = _is_package_available("aiter")
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
 _av_available, _av_version = _is_package_available("av")
+_flashpack_available, _flashpack_version = _is_package_available("flashpack")


 def is_torch_available():
     return _torch_available
@@ -424,6 +424,8 @@ def is_kornia_available():
 def is_av_available():
     return _av_available

+def is_flashpack_available():
+    return _flashpack_available

 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
@@ -941,6 +943,14 @@ def is_aiter_version(operation: str, version: str):
         return False
     return compare_versions(parse(_aiter_version), operation, version)

+@cache
+def is_flashpack_version(operation: str, version: str):
+    """
+    Compares the current flashpack version to a given reference with an operation.
+    """
+    if not _flashpack_available:
+        return False
+    return compare_versions(parse(_flashpack_version), operation, version)

 def get_objects_from_module(module):
     """