added fal-flashpack support
@@ -675,6 +675,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        variant: Optional[str] = None,
        max_shard_size: Union[int, str] = "10GB",
        push_to_hub: bool = False,
        use_flashpack: bool = False,
        **kwargs,
    ):
        """
@@ -707,6 +708,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            use_flashpack (`bool`, *optional*, defaults to `False`):
                Whether to save the model in FlashPack format. FlashPack is a binary format that allows for faster
                loading. Requires the `flashpack` library to be installed.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
@@ -734,6 +738,44 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        )

        os.makedirs(save_directory, exist_ok=True)
        if use_flashpack:
            if not is_main_process:
                return

            try:
                import json

                from flashpack import pack_to_file
            except ImportError:
                raise ImportError("The `use_flashpack=True` argument requires the `flashpack` library.")

            flashpack_weights_name = _add_variant("model.flashpack", variant)
            save_path = os.path.join(save_directory, flashpack_weights_name)
            # The metadata config is saved next to the binary weights.
            config_save_path = os.path.join(save_directory, "flashpack_config.json")

            try:
                target_dtype = getattr(self, "dtype", None)
                logger.debug(f"Dtype used: {target_dtype}")
                # 1. Save the binary weights.
                pack_to_file(self, save_path, target_dtype=target_dtype)

                # 2. Save the metadata config.
                if hasattr(self, "config"):
                    try:
                        # Prefer the dictionary representation when available.
                        if hasattr(self.config, "to_dict"):
                            config_data = self.config.to_dict()
                        else:
                            config_data = dict(self.config)

                        with open(config_save_path, "w") as f:
                            json.dump(config_data, f, indent=4)
                    except Exception as config_err:
                        logger.warning(f"Weights saved but config serialization failed: {config_err}")

                logger.info(f"Model weights saved in FlashPack format at {save_path}")
            except Exception as e:
                logger.error(f"Failed to save weights in FlashPack format: {e}")

        if push_to_hub:
            commit_message = kwargs.pop("commit_message", None)
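For context, a minimal usage sketch of the save path added above; the repo id and output directory are placeholders, and `flashpack` must be installed:

    from diffusers import UNet2DConditionModel

    # Placeholder repo id, used only for illustration.
    model = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-2-1", subfolder="unet"
    )
    # Writes `model.flashpack` plus `flashpack_config.json` into the directory,
    # alongside the files the standard save path produces.
    model.save_pretrained("./unet-flashpack", use_flashpack=True)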
@@ -939,6 +981,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                weights. If set to `False`, `safetensors` weights are not loaded.
            use_flashpack (`bool`, *optional*, defaults to `False`):
                If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
                is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading automatically
                falls back to the standard path (for example, `safetensors`). Requires the `flashpack` library:
                `pip install flashpack`.
            disable_mmap (`bool`, *optional*, defaults to `False`):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -982,6 +1029,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        use_flashpack = kwargs.pop("use_flashpack", False)
        quantization_config = kwargs.pop("quantization_config", None)
        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
        disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1200,6 +1248,69 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

            model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
        else:
            # If `use_flashpack` is set, try to load the model from a FlashPack file first.
            # If flashpack is not available or the file cannot be loaded, fall back to
            # the standard loading path (e.g. safetensors or PyTorch).
            if use_flashpack:
                try:
                    from flashpack import assign_from_file
                except ImportError:
                    pass
                else:
                    flashpack_weights_name = _add_variant("model.flashpack", variant)
                    try:
                        flashpack_file = _get_model_file(
                            pretrained_model_name_or_path,
                            weights_name=flashpack_weights_name,
                            cache_dir=cache_dir,
                            force_download=force_download,
                            proxies=proxies,
                            local_files_only=local_files_only,
                            token=token,
                            revision=revision,
                            subfolder=subfolder,
                            user_agent=user_agent,
                            commit_hash=commit_hash,
                        )
                    except EnvironmentError:
                        pass
                    else:
                        dtype_orig = None
                        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
                            if not isinstance(torch_dtype, torch.dtype):
                                raise ValueError(
                                    f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
                                )
                            dtype_orig = cls._set_default_torch_dtype(torch_dtype)

                        with no_init_weights():
                            model = cls.from_config(config, **unused_kwargs)

                        if dtype_orig is not None:
                            torch.set_default_dtype(dtype_orig)

                        try:
                            assign_from_file(model, flashpack_file)
                            model.register_to_config(_name_or_path=pretrained_model_name_or_path)

                            if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
                                model = model.to(torch_dtype)

                            model.eval()

                            if output_loading_info:
                                loading_info = {
                                    "missing_keys": [],
                                    "unexpected_keys": [],
                                    "mismatched_keys": [],
                                    "error_msgs": [],
                                }
                                return model, loading_info

                            return model

                        except Exception:
                            pass
            # if the checkpoint is sharded, we already have the index
            if is_sharded:
                resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
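A corresponding load sketch (the path is a placeholder): with `use_flashpack=True`, the branch above tries `model.flashpack` first and, per its `except`/`pass` arms, silently falls back to the standard safetensors/PyTorch path when the file or the library is missing:

    import torch
    from diffusers import UNet2DConditionModel

    model = UNet2DConditionModel.from_pretrained(
        "./unet-flashpack",   # directory written by the save example above
        use_flashpack=True,   # try model.flashpack first, fall back otherwise
        torch_dtype=torch.float16,
    )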
@@ -756,6 +756,7 @@ def load_sub_model(
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
    use_safetensors: bool,
    use_flashpack: bool,
    dduf_entries: Optional[Dict[str, DDUFEntry]],
    provider_options: Any,
    disable_mmap: bool,
@@ -838,6 +839,9 @@
        loading_kwargs["variant"] = model_variants.pop(name, None)
        loading_kwargs["use_safetensors"] = use_safetensors

        if is_diffusers_model:
            loading_kwargs["use_flashpack"] = use_flashpack

        if from_flax:
            loading_kwargs["from_flax"] = True
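The `is_diffusers_model` guard matters because non-diffusers components (for example, `transformers` text encoders) do not accept `use_flashpack` and would raise a `TypeError` if the kwarg were forwarded. A minimal sketch of the same idea; the helper name is hypothetical:

    from diffusers import ModelMixin

    def supports_flashpack_kwarg(component_cls) -> bool:
        # Only diffusers models gained the `use_flashpack` parameter in this commit.
        return issubclass(component_cls, ModelMixin)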
@@ -243,6 +243,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        variant: Optional[str] = None,
        max_shard_size: Optional[Union[int, str]] = None,
        push_to_hub: bool = False,
        use_flashpack: bool = False,
        **kwargs,
    ):
        """
@@ -268,7 +269,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            use_flashpack (`bool`, *optional*, defaults to `False`):
                Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library:
                `pip install flashpack`.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
@@ -340,6 +343,7 @@
            save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
            save_method_accept_variant = "variant" in save_method_signature.parameters
            save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
            save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters

            save_kwargs = {}
            if save_method_accept_safe:
@@ -349,6 +353,8 @@
            if save_method_accept_max_shard_size and max_shard_size is not None:
                # max_shard_size is expected to not be None in ModelMixin
                save_kwargs["max_shard_size"] = max_shard_size
            if save_method_accept_flashpack:
                save_kwargs["use_flashpack"] = use_flashpack

            save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
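The hunks above rely on signature inspection so that `use_flashpack` is forwarded only to components whose save method accepts it. A standalone sketch of that pattern; the helper name is hypothetical:

    import inspect

    def call_with_supported_kwargs(save_method, path, **candidate_kwargs):
        # Forward only the kwargs the component's save method actually accepts.
        params = inspect.signature(save_method).parameters
        supported = {k: v for k, v in candidate_kwargs.items() if k in params}
        save_method(path, **supported)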
@@ -707,6 +713,11 @@
                If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
                weights. If set to `False`, safetensors weights are not loaded.
            use_flashpack (`bool`, *optional*, defaults to `False`):
                If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
                is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading automatically
                falls back to the standard path (for example, `safetensors`). Requires the `flashpack` library:
                `pip install flashpack`.
            use_onnx (`bool`, *optional*, defaults to `None`):
                If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
                will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
@@ -772,6 +783,7 @@
        variant = kwargs.pop("variant", None)
        dduf_file = kwargs.pop("dduf_file", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        use_flashpack = kwargs.pop("use_flashpack", False)
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
        quantization_config = kwargs.pop("quantization_config", None)
@@ -1061,6 +1073,7 @@
            low_cpu_mem_usage=low_cpu_mem_usage,
            cached_folder=cached_folder,
            use_safetensors=use_safetensors,
            use_flashpack=use_flashpack,
            dduf_entries=dduf_entries,
            provider_options=provider_options,
            disable_mmap=disable_mmap,
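At the pipeline level, the flag is simply threaded through to each diffusers sub-model via `load_sub_model`. A usage sketch (the repo id is a placeholder; sub-models without `.flashpack` files fall back to their normal weights):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1",  # placeholder repo id
        use_flashpack=True,                  # forwarded only to diffusers sub-models
        torch_dtype=torch.float16,
    )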