
added fal-flashpack support

This commit is contained in:
devanshi00
2026-01-19 14:52:15 +05:30
parent 3996788b60
commit ec541906c5
3 changed files with 129 additions and 1 deletion


@@ -675,6 +675,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        variant: Optional[str] = None,
        max_shard_size: Union[int, str] = "10GB",
        push_to_hub: bool = False,
        use_flashpack: bool = False,
        **kwargs,
    ):
        """
"""
@@ -707,6 +708,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            use_flashpack (`bool`, *optional*, defaults to `False`):
                Whether to save the model in FlashPack format. FlashPack is a binary format that allows for faster
                loading. Requires the `flashpack` library to be installed.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
@@ -734,6 +738,44 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            )

        os.makedirs(save_directory, exist_ok=True)

        if use_flashpack:
            if not is_main_process:
                return

            try:
                import json

                from flashpack import pack_to_file
            except ImportError:
                raise ImportError("The `use_flashpack=True` argument requires the `flashpack` library.")

            flashpack_weights_name = _add_variant("model.flashpack", variant)
            save_path = os.path.join(save_directory, flashpack_weights_name)
            # Save the config next to the weights so loaders can reconstruct the model.
            config_save_path = os.path.join(save_directory, "flashpack_config.json")

            try:
                target_dtype = getattr(self, "dtype", None)
                logger.debug(f"Saving FlashPack weights with dtype: {target_dtype}")

                # 1. Save the binary weights.
                pack_to_file(self, save_path, target_dtype=target_dtype)

                # 2. Save the metadata config.
                if hasattr(self, "config"):
                    try:
                        # Prefer the dictionary representation when available.
                        if hasattr(self.config, "to_dict"):
                            config_data = self.config.to_dict()
                        else:
                            config_data = dict(self.config)
                        with open(config_save_path, "w") as f:
                            json.dump(config_data, f, indent=4)
                    except Exception as config_err:
                        logger.warning(f"Weights saved but config serialization failed: {config_err}")

                logger.info(f"Model weights saved in FlashPack format at {save_path}")
            except Exception as e:
                # Fall through to the standard save path below so a usable checkpoint is still written.
                logger.error(f"Failed to save weights in FlashPack format: {e}")

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
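
To sanity-check what the branch above writes to disk, one can inspect the two FlashPack artifacts directly. A small sketch, assuming the save succeeded and `./unet-flashpack` was the target directory:

import json
import os

save_directory = "./unet-flashpack"

# The binary weights written by `pack_to_file` ...
assert os.path.exists(os.path.join(save_directory, "model.flashpack"))

# ... and the sidecar config written alongside them.
with open(os.path.join(save_directory, "flashpack_config.json")) as f:
    config_data = json.load(f)
print(sorted(config_data)[:5])  # a few of the serialized config keys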
@@ -939,6 +981,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.
            use_flashpack (`bool`, *optional*, defaults to `False`):
                If set to `True`, the model is first loaded from FlashPack weights when a compatible `.flashpack`
                file is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading
                automatically falls back to the standard path (for example, `safetensors`). Requires the
                `flashpack` library: `pip install flashpack`.
            disable_mmap (`bool`, *optional*, defaults to `False`):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the
                model is on a network mount or hard drive, which may not handle the random-access pattern of mmap
                very well.
@@ -982,6 +1029,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        use_flashpack = kwargs.pop("use_flashpack", False)
        quantization_config = kwargs.pop("quantization_config", None)
        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
        disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1200,6 +1248,69 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
        else:
            # If `use_flashpack` is set, try to load the model from FlashPack first.
            # If flashpack is not available or the file cannot be loaded, fall back to
            # the standard loading path (e.g. safetensors or PyTorch).
            if use_flashpack:
                try:
                    from flashpack import assign_from_file
                except ImportError:
                    logger.warning(
                        "`use_flashpack=True` was passed but the `flashpack` library is not installed; "
                        "falling back to the standard loading path."
                    )
                else:
                    flashpack_weights_name = _add_variant("model.flashpack", variant)
                    try:
                        flashpack_file = _get_model_file(
                            pretrained_model_name_or_path,
                            weights_name=flashpack_weights_name,
                            cache_dir=cache_dir,
                            force_download=force_download,
                            proxies=proxies,
                            local_files_only=local_files_only,
                            token=token,
                            revision=revision,
                            subfolder=subfolder,
                            user_agent=user_agent,
                            commit_hash=commit_hash,
                        )
                    except EnvironmentError:
                        # No `.flashpack` file for this checkpoint; fall back to the standard path.
                        pass
                    else:
                        dtype_orig = None
                        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
                            if not isinstance(torch_dtype, torch.dtype):
                                raise ValueError(
                                    f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
                                )
                            dtype_orig = cls._set_default_torch_dtype(torch_dtype)

                        with no_init_weights():
                            model = cls.from_config(config, **unused_kwargs)

                        if dtype_orig is not None:
                            torch.set_default_dtype(dtype_orig)

                        try:
                            assign_from_file(model, flashpack_file)
                            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
                            if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
                                model = model.to(torch_dtype)
                            model.eval()

                            if output_loading_info:
                                # FlashPack assigns all tensors in one pass, so there is nothing to report.
                                loading_info = {
                                    "missing_keys": [],
                                    "unexpected_keys": [],
                                    "mismatched_keys": [],
                                    "error_msgs": [],
                                }
                                return model, loading_info
                            return model
                        except Exception as e:
                            logger.warning(
                                f"Could not load FlashPack weights ({e}); falling back to the standard loading path."
                            )

            # in the case it is sharded, we have already the index
            if is_sharded:
                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
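
Loading mirrors the save side: the flag is opt-in, and any failure (missing library, missing file, load error) degrades to the usual safetensors/PyTorch path. A minimal sketch, reusing the hypothetical directory from the save example above:

from diffusers import UNet2DConditionModel

# Tries `model.flashpack` first; silently falls back to the standard
# loading path if FlashPack cannot be used.
unet = UNet2DConditionModel.from_pretrained("./unet-flashpack", use_flashpack=True)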


@@ -756,6 +756,7 @@ def load_sub_model(
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
    use_safetensors: bool,
    use_flashpack: bool,
    dduf_entries: Optional[Dict[str, DDUFEntry]],
    provider_options: Any,
    disable_mmap: bool,
@@ -838,6 +839,9 @@ def load_sub_model(
        loading_kwargs["variant"] = model_variants.pop(name, None)
        loading_kwargs["use_safetensors"] = use_safetensors

        # Only diffusers models understand `use_flashpack`.
        if is_diffusers_model:
            loading_kwargs["use_flashpack"] = use_flashpack

        if from_flax:
            loading_kwargs["from_flax"] = True


@@ -243,6 +243,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        variant: Optional[str] = None,
        max_shard_size: Optional[Union[int, str]] = None,
        push_to_hub: bool = False,
        use_flashpack: bool = False,
        **kwargs,
    ):
        """
"""
@@ -268,7 +269,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            use_flashpack (`bool`, *optional*, defaults to `False`):
                Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library:
                `pip install flashpack`.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
@@ -340,6 +343,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
            save_method_accept_variant = "variant" in save_method_signature.parameters
            save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
            save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters

            save_kwargs = {}
            if save_method_accept_safe:
@@ -349,6 +353,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            if save_method_accept_max_shard_size and max_shard_size is not None:
                # max_shard_size is expected to not be None in ModelMixin
                save_kwargs["max_shard_size"] = max_shard_size
            if save_method_accept_flashpack:
                save_kwargs["use_flashpack"] = use_flashpack

            save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
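
The `save_method_accept_flashpack` check follows the pipeline's existing capability-detection pattern: inspect each component's save method and forward a kwarg only when the signature accepts it, so non-diffusers components such as tokenizers never see `use_flashpack`. The pattern in isolation (a sketch; the helper name is hypothetical):

import inspect

def build_save_kwargs(save_method, use_flashpack: bool) -> dict:
    # Forward `use_flashpack` only to save methods that declare it.
    params = inspect.signature(save_method).parameters
    return {"use_flashpack": use_flashpack} if "use_flashpack" in params else {}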
@@ -707,6 +713,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
            use_flashpack (`bool`, *optional*, defaults to `False`):
                If set to `True`, the model is first loaded from FlashPack weights when a compatible `.flashpack`
                file is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading
                automatically falls back to the standard path (for example, `safetensors`). Requires the
                `flashpack` library: `pip install flashpack`.
            use_onnx (`bool`, *optional*, defaults to `None`):
                If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
                will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
@@ -772,6 +783,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        variant = kwargs.pop("variant", None)
        dduf_file = kwargs.pop("dduf_file", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        use_flashpack = kwargs.pop("use_flashpack", False)
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
        quantization_config = kwargs.pop("quantization_config", None)
@@ -1061,6 +1073,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                low_cpu_mem_usage=low_cpu_mem_usage,
                cached_folder=cached_folder,
                use_safetensors=use_safetensors,
                use_flashpack=use_flashpack,
                dduf_entries=dduf_entries,
                provider_options=provider_options,
                disable_mmap=disable_mmap,
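
Put together, the pipeline-level round trip under this patch would look roughly as follows (model id and local path are illustrative):

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

# Each diffusers sub-model (e.g. the UNet and VAE) gets a `model.flashpack`
# file next to its usual weights; transformers components are saved as usual.
pipe.save_pretrained("./sd15-flashpack", use_flashpack=True)

# Sub-models are loaded via FlashPack where possible, safetensors otherwise.
pipe = DiffusionPipeline.from_pretrained("./sd15-flashpack", use_flashpack=True)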