diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 63e50af617..a4bfa47f6d 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -675,6 +675,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         variant: Optional[str] = None,
         max_shard_size: Union[int, str] = "10GB",
         push_to_hub: bool = False,
+        use_flashpack: bool = False,
         **kwargs,
     ):
         """
@@ -707,6 +708,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                Whether to save the model in FlashPack format. FlashPack is a binary format that allows for faster
+                loading. Requires the `flashpack` library to be installed.
             kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
@@ -734,6 +738,44 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             )

         os.makedirs(save_directory, exist_ok=True)

+        if use_flashpack:
+            if not is_main_process:
+                return
+
+            import json
+
+            try:
+                from flashpack import pack_to_file
+            except ImportError:
+                raise ImportError("The `use_flashpack=True` argument requires the `flashpack` library.")
+
+            flashpack_weights_name = _add_variant("model.flashpack", variant)
+            save_path = os.path.join(save_directory, flashpack_weights_name)
+            # Metadata config saved alongside the binary weights.
+            config_save_path = os.path.join(save_directory, "flashpack_config.json")
+
+            try:
+                target_dtype = getattr(self, "dtype", None)
+                logger.debug(f"Saving FlashPack weights with dtype: {target_dtype}")
+                # 1. Save the binary weights.
+                pack_to_file(self, save_path, target_dtype=target_dtype)
+
+                # 2. Save the metadata config.
+                if hasattr(self, "config"):
+                    try:
+                        # Prefer the dictionary representation when available.
+                        if hasattr(self.config, "to_dict"):
+                            config_data = self.config.to_dict()
+                        else:
+                            config_data = dict(self.config)
+
+                        with open(config_save_path, "w") as f:
+                            json.dump(config_data, f, indent=4)
+                    except Exception as config_err:
+                        logger.warning(f"Weights saved but config serialization failed: {config_err}")
+
+                logger.info(f"Model weights saved in FlashPack format at {save_path}")
+            except Exception as e:
+                logger.error(f"Failed to save weights in FlashPack format: {e}")
+
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
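With the save path above, a model serialized with `use_flashpack=True` ends up as `model.flashpack` plus a `flashpack_config.json` next to it. A minimal sketch of the save side (the checkpoint id is just an example; any `ModelMixin` subclass should behave the same way):

```python
from diffusers import AutoencoderKL

# Example checkpoint; any diffusers ModelMixin subclass works the same way.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# Writes `model.flashpack` and `flashpack_config.json` into the directory.
vae.save_pretrained("./vae-flashpack", use_flashpack=True)
```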
@@ -939,6 +981,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
                 `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
                 weights. If set to `False`, `safetensors` weights are not loaded.
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack`
+                file is found. If `flashpack` is unavailable or the `.flashpack` file cannot be used, loading
+                automatically falls back to the standard path (for example, `safetensors`). Requires the `flashpack`
+                library: `pip install flashpack`.
             disable_mmap ('bool', *optional*, defaults to 'False'):
                 Whether to disable mmap when loading a Safetensors model. This option can perform better when the
                 model is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -982,6 +1029,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        use_flashpack = kwargs.pop("use_flashpack", False)
         quantization_config = kwargs.pop("quantization_config", None)
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1200,6 +1248,69 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
         else:
+            # If `use_flashpack` is set, try to load the model from FlashPack weights first.
+            # If flashpack is not available or the file cannot be loaded, fall back to the
+            # standard loading path (e.g. safetensors or PyTorch).
+            if use_flashpack:
+                try:
+                    from flashpack import assign_from_file
+                except ImportError:
+                    pass
+                else:
+                    flashpack_weights_name = _add_variant("model.flashpack", variant)
+                    try:
+                        flashpack_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=flashpack_weights_name,
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            token=token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                            commit_hash=commit_hash,
+                        )
+                    except EnvironmentError:
+                        pass
+                    else:
+                        dtype_orig = None
+                        if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
+                            if not isinstance(torch_dtype, torch.dtype):
+                                raise ValueError(
+                                    f"{torch_dtype} needs to be a `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+                                )
+                            dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+
+                        with no_init_weights():
+                            model = cls.from_config(config, **unused_kwargs)
+
+                        if dtype_orig is not None:
+                            torch.set_default_dtype(dtype_orig)
+
+                        try:
+                            assign_from_file(model, flashpack_file)
+                            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+                            if torch_dtype is not None and torch_dtype != getattr(torch, "float8_e4m3fn", None):
+                                model = model.to(torch_dtype)
+
+                            model.eval()
+
+                            if output_loading_info:
+                                loading_info = {
+                                    "missing_keys": [],
+                                    "unexpected_keys": [],
+                                    "mismatched_keys": [],
+                                    "error_msgs": [],
+                                }
+                                return model, loading_info
+
+                            return model
+                        except Exception:
+                            pass
+
             # in the case it is sharded, we have already the index
             if is_sharded:
                 resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
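The loading branch above mirrors the save path. A sketch of loading the directory written in the previous example, assuming `flashpack` is installed; note that every failure mode (missing library, missing file, assignment error) falls through silently to the regular loader:

```python
import torch
from diffusers import AutoencoderKL

# Loads from `model.flashpack` when present; if the file or the flashpack
# library is missing, this silently falls back to safetensors/PyTorch weights.
vae = AutoencoderKL.from_pretrained(
    "./vae-flashpack",
    use_flashpack=True,
    torch_dtype=torch.float16,
)
```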
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 57d4eaa8f8..2fbb1028f9 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -756,6 +756,7 @@ def load_sub_model(
     low_cpu_mem_usage: bool,
     cached_folder: Union[str, os.PathLike],
     use_safetensors: bool,
+    use_flashpack: bool,
     dduf_entries: Optional[Dict[str, DDUFEntry]],
     provider_options: Any,
     disable_mmap: bool,
@@ -838,6 +839,9 @@ def load_sub_model(
         loading_kwargs["variant"] = model_variants.pop(name, None)
         loading_kwargs["use_safetensors"] = use_safetensors

+    if is_diffusers_model:
+        loading_kwargs["use_flashpack"] = use_flashpack
+
     if from_flax:
         loading_kwargs["from_flax"] = True

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index b96305c741..34e42f4286 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -243,6 +243,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         variant: Optional[str] = None,
         max_shard_size: Optional[Union[int, str]] = None,
         push_to_hub: bool = False,
+        use_flashpack: bool = False,
         **kwargs,
     ):
         """
@@ -268,7 +269,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
-
+            use_flashpack (`bool`, *optional*, defaults to `False`):
+                Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
+                install flashpack`.
             kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
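At the pipeline level the flag simply fans out to the components, as the signature check in the next hunk shows. A sketch, assuming the example checkpoint id below exists on the Hub:

```python
from diffusers import DiffusionPipeline

# Example checkpoint id; substitute any diffusers pipeline.
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

# Only components whose `save_pretrained` accepts `use_flashpack` (i.e. diffusers
# models) receive the flag; schedulers, tokenizers, etc. are saved as before.
pipe.save_pretrained("./sd15-flashpack", use_flashpack=True)
```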
""" @@ -340,6 +343,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): save_method_accept_safe = "safe_serialization" in save_method_signature.parameters save_method_accept_variant = "variant" in save_method_signature.parameters save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters + save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters save_kwargs = {} if save_method_accept_safe: @@ -349,6 +353,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): if save_method_accept_max_shard_size and max_shard_size is not None: # max_shard_size is expected to not be None in ModelMixin save_kwargs["max_shard_size"] = max_shard_size + if save_method_accept_flashpack: + save_kwargs["use_flashpack"] = use_flashpack save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs) @@ -707,6 +713,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. + use_flashpack (`bool`, *optional*, defaults to `False`): + If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file + is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to + the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install + flashpack`. use_onnx (`bool`, *optional*, defaults to `None`): If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is @@ -772,6 +783,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): variant = kwargs.pop("variant", None) dduf_file = kwargs.pop("dduf_file", None) use_safetensors = kwargs.pop("use_safetensors", None) + use_flashpack = kwargs.pop("use_flashpack", False) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) quantization_config = kwargs.pop("quantization_config", None) @@ -1061,6 +1073,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, use_safetensors=use_safetensors, + use_flashpack=use_flashpack, dduf_entries=dduf_entries, provider_options=provider_options, disable_mmap=disable_mmap,