From fbff43acc9f52aec18e27806cc258a592f8b53f6 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Tue, 14 Jan 2025 08:51:42 +0100 Subject: [PATCH] [FEAT] DDUF format (#10037) * load and save dduf archive * style * switch to zip uncompressed * updates * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Sayak Paul * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Sayak Paul * first draft * remove print * switch to dduf_file for consistency * switch to huggingface hub api * fix log * add a basic test * Update src/diffusers/configuration_utils.py Co-authored-by: Sayak Paul * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Sayak Paul * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Sayak Paul * fix * fix variant * change saving logic * DDUF - Load transformers components manually (#10171) * update hfh version * Load transformers components manually * load encoder from_pretrained with state_dict * working version with transformers and tokenizer ! * add generation_config case * fix tests * remove saving for now * typing * need next version from transformers * Update src/diffusers/configuration_utils.py Co-authored-by: Lucain * check path corectly * Apply suggestions from code review Co-authored-by: Lucain * udapte * typing * remove check for subfolder * quality * revert setup changes * oups * more readable condition * add loading from the hub test * add basic docs. * Apply suggestions from code review Co-authored-by: Lucain * add example * add * make functions private * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * minor. * fixes * fix * change the precdence of parameterized. * error out when custom pipeline is passed with dduf_file. * updates * fix * updates * fixes * updates * fix xfail condition. * fix xfail * fixes * sharded checkpoint compat * add test for sharded checkpoint * add suggestions * Update src/diffusers/models/model_loading_utils.py Co-authored-by: YiYi Xu * from suggestions * add class attributes to flag dduf tests * last one * fix logic * remove comment * revert changes --------- Co-authored-by: Sayak Paul Co-authored-by: Lucain Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: YiYi Xu --- .../en/using-diffusers/other-formats.md | 40 ++++++ setup.py | 2 +- src/diffusers/configuration_utils.py | 45 +++++-- src/diffusers/dependency_versions_table.py | 2 +- src/diffusers/models/model_loading_utils.py | 44 +++++-- src/diffusers/models/modeling_utils.py | 27 +++- .../pipelines/pipeline_loading_utils.py | 105 +++++++++++++-- src/diffusers/pipelines/pipeline_utils.py | 50 +++++++- .../pipelines/transformers_loading_utils.py | 121 ++++++++++++++++++ src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/hub_utils.py | 38 +++++- src/diffusers/utils/import_utils.py | 22 ++++ src/diffusers/utils/testing_utils.py | 12 ++ tests/pipelines/allegro/test_allegro.py | 33 +++++ tests/pipelines/audioldm/test_audioldm.py | 2 + tests/pipelines/audioldm2/test_audioldm2.py | 2 + .../blipdiffusion/test_blipdiffusion.py | 2 + tests/pipelines/controlnet/test_controlnet.py | 4 + .../test_controlnet_blip_diffusion.py | 2 + .../controlnet/test_controlnet_img2img.py | 2 + .../controlnet/test_controlnet_inpaint.py | 2 + .../test_controlnet_inpaint_sdxl.py | 2 + .../controlnet/test_controlnet_sdxl.py | 4 + tests/pipelines/deepfloyd_if/test_if.py | 7 + .../pipelines/deepfloyd_if/test_if_img2img.py | 7 + .../test_if_img2img_superresolution.py | 7 + .../deepfloyd_if/test_if_inpainting.py | 7 + .../test_if_inpainting_superresolution.py | 7 + .../deepfloyd_if/test_if_superresolution.py | 7 + tests/pipelines/i2vgen_xl/test_i2vgenxl.py | 2 + tests/pipelines/kandinsky/test_kandinsky.py | 2 + .../kandinsky/test_kandinsky_combined.py | 6 + .../kandinsky/test_kandinsky_img2img.py | 2 + .../kandinsky/test_kandinsky_inpaint.py | 2 + .../kandinsky/test_kandinsky_prior.py | 2 + .../kandinsky2_2/test_kandinsky_combined.py | 6 + .../kandinsky2_2/test_kandinsky_prior.py | 2 + .../test_kandinsky_prior_emb2emb.py | 2 + tests/pipelines/kolors/test_kolors.py | 2 + tests/pipelines/kolors/test_kolors_img2img.py | 2 + tests/pipelines/lumina/test_lumina_nextdit.py | 2 + tests/pipelines/musicldm/test_musicldm.py | 2 + tests/pipelines/pag/test_pag_kolors.py | 2 + tests/pipelines/pag/test_pag_sana.py | 2 + tests/pipelines/pag/test_pag_sdxl_img2img.py | 2 + tests/pipelines/pag/test_pag_sdxl_inpaint.py | 2 + .../paint_by_example/test_paint_by_example.py | 2 + tests/pipelines/shap_e/test_shap_e_img2img.py | 2 + .../stable_audio/test_stable_audio.py | 1 + .../test_stable_diffusion_depth.py | 2 + .../test_stable_diffusion_adapter.py | 2 + ...test_stable_diffusion_gligen_text_image.py | 2 + .../test_stable_diffusion_image_variation.py | 2 + .../test_stable_diffusion_xl_adapter.py | 2 + .../test_stable_diffusion_xl_img2img.py | 2 + .../test_stable_diffusion_xl_inpaint.py | 2 + .../test_stable_unclip_img2img.py | 2 + .../test_stable_video_diffusion.py | 2 + tests/pipelines/test_pipelines.py | 84 ++++++++++++ tests/pipelines/test_pipelines_common.py | 37 ++++++ .../unclip/test_unclip_image_variation.py | 1 + .../pipelines/unidiffuser/test_unidiffuser.py | 2 + 62 files changed, 750 insertions(+), 45 deletions(-) create mode 100644 src/diffusers/pipelines/transformers_loading_utils.py diff --git a/docs/source/en/using-diffusers/other-formats.md b/docs/source/en/using-diffusers/other-formats.md index 24ac9ced84..e662e3940a 100644 --- a/docs/source/en/using-diffusers/other-formats.md +++ b/docs/source/en/using-diffusers/other-formats.md @@ -240,6 +240,46 @@ Benefits of using a single-file layout include: 1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout. 2. Easier to manage (download and share) a single file. +### DDUF + +> [!WARNING] +> DDUF is an experimental file format and APIs related to it can change in the future. + +DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers multi-folder format and the widely popular single-file format. + +Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf). + +Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`]. + +```py +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained( + "DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16 +).to("cuda") +image = pipe( + "photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5 +).images[0] +image.save("cat.png") +``` + +To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations. + +```py +from huggingface_hub import export_folder_as_dduf +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + +save_folder = "flux-dev" +pipe.save_pretrained("flux-dev") +export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder) + +> [!TIP] +> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure. + ## Convert layout and files Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem. diff --git a/setup.py b/setup.py index d696c14ca8..0acdcbbb9c 100644 --- a/setup.py +++ b/setup.py @@ -101,7 +101,7 @@ _deps = [ "filelock", "flax>=0.4.1", "hf-doc-builder>=0.3.0", - "huggingface-hub>=0.23.2", + "huggingface-hub>=0.27.0", "requests-mock==1.10.0", "importlib_metadata", "invisible-watermark>=0.2.0", diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index d21ada6fbe..9dd4f0121a 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -24,10 +24,10 @@ import os import re from collections import OrderedDict from pathlib import Path -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np -from huggingface_hub import create_repo, hf_hub_download +from huggingface_hub import DDUFEntry, create_repo, hf_hub_download from huggingface_hub.utils import ( EntryNotFoundError, RepositoryNotFoundError, @@ -347,6 +347,7 @@ class ConfigMixin: _ = kwargs.pop("mirror", None) subfolder = kwargs.pop("subfolder", None) user_agent = kwargs.pop("user_agent", {}) + dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) user_agent = {**user_agent, "file_type": "config"} user_agent = http_user_agent(user_agent) @@ -358,8 +359,15 @@ class ConfigMixin: "`self.config_name` is not defined. Note that one should not load a config from " "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" ) - - if os.path.isfile(pretrained_model_name_or_path): + # Custom path for now + if dduf_entries: + if subfolder is not None: + raise ValueError( + "DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). " + "Please check the DDUF structure" + ) + config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries) + elif os.path.isfile(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path elif os.path.isdir(pretrained_model_name_or_path): if subfolder is not None and os.path.isfile( @@ -426,10 +434,8 @@ class ConfigMixin: f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f"containing a {cls.config_name} file" ) - try: - # Load config dict - config_dict = cls._dict_from_json_file(config_file) + config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries) commit_hash = extract_commit_hash(config_file) except (json.JSONDecodeError, UnicodeDecodeError): @@ -552,9 +558,14 @@ class ConfigMixin: return init_dict, unused_kwargs, hidden_config_dict @classmethod - def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): - with open(json_file, "r", encoding="utf-8") as reader: - text = reader.read() + def _dict_from_json_file( + cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None + ): + if dduf_entries: + text = dduf_entries[json_file].read_text() + else: + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() return json.loads(text) def __repr__(self): @@ -616,6 +627,20 @@ class ConfigMixin: with open(json_file_path, "w", encoding="utf-8") as writer: writer.write(self.to_json_string()) + @classmethod + def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]): + # paths inside a DDUF file must always be "/" + config_file = ( + cls.config_name + if pretrained_model_name_or_path == "" + else "/".join([pretrained_model_name_or_path, cls.config_name]) + ) + if config_file not in dduf_entries: + raise ValueError( + f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}" + ) + return config_file + def register_to_config(init): r""" diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index bb5a54f734..7999368f14 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -9,7 +9,7 @@ deps = { "filelock": "filelock", "flax": "flax>=0.4.1", "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.23.2", + "huggingface-hub": "huggingface-hub>=0.27.0", "requests-mock": "requests-mock==1.10.0", "importlib_metadata": "importlib_metadata", "invisible-watermark": "invisible-watermark>=0.2.0", diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index a3d006f189..386c07e874 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -20,10 +20,11 @@ import os from array import array from collections import OrderedDict from pathlib import Path -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import safetensors import torch +from huggingface_hub import DDUFEntry from huggingface_hub.utils import EntryNotFoundError from ..utils import ( @@ -132,7 +133,10 @@ def _fetch_remapped_cls_from_config(config, old_class): def load_state_dict( - checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False + checkpoint_file: Union[str, os.PathLike], + variant: Optional[str] = None, + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, + disable_mmap: bool = False, ): """ Reads a checkpoint file, returning properly formatted errors if they arise. @@ -144,6 +148,10 @@ def load_state_dict( try: file_extension = os.path.basename(checkpoint_file).split(".")[-1] if file_extension == SAFETENSORS_FILE_EXTENSION: + if dduf_entries: + # tensors are loaded on cpu + with dduf_entries[checkpoint_file].as_mmap() as mm: + return safetensors.torch.load(mm) if disable_mmap: return safetensors.torch.load(open(checkpoint_file, "rb").read()) else: @@ -284,6 +292,7 @@ def _fetch_index_file( revision, user_agent, commit_hash, + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, ): if is_local: index_file = Path( @@ -309,8 +318,10 @@ def _fetch_index_file( subfolder=None, user_agent=user_agent, commit_hash=commit_hash, + dduf_entries=dduf_entries, ) - index_file = Path(index_file) + if not dduf_entries: + index_file = Path(index_file) except (EntryNotFoundError, EnvironmentError): index_file = None @@ -319,7 +330,9 @@ def _fetch_index_file( # Adapted from # https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64 -def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata): +def _merge_sharded_checkpoints( + sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None +): weight_map = sharded_metadata.get("weight_map", None) if weight_map is None: raise KeyError("'weight_map' key not found in the shard index file.") @@ -332,14 +345,23 @@ def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata): # Load tensors from each unique file for file_name in files_to_load: part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name) - if not os.path.exists(part_file_path): - raise FileNotFoundError(f"Part file {file_name} not found.") + if dduf_entries: + if part_file_path not in dduf_entries: + raise FileNotFoundError(f"Part file {file_name} not found.") + else: + if not os.path.exists(part_file_path): + raise FileNotFoundError(f"Part file {file_name} not found.") if is_safetensors: - with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: - for tensor_key in f.keys(): - if tensor_key in weight_map: - merged_state_dict[tensor_key] = f.get_tensor(tensor_key) + if dduf_entries: + with dduf_entries[part_file_path].as_mmap() as mm: + tensors = safetensors.torch.load(mm) + merged_state_dict.update(tensors) + else: + with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: + for tensor_key in f.keys(): + if tensor_key in weight_map: + merged_state_dict[tensor_key] = f.get_tensor(tensor_key) else: merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu")) @@ -360,6 +382,7 @@ def _fetch_index_file_legacy( revision, user_agent, commit_hash, + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, ): if is_local: index_file = Path( @@ -400,6 +423,7 @@ def _fetch_index_file_legacy( subfolder=None, user_agent=user_agent, commit_hash=commit_hash, + dduf_entries=dduf_entries, ) index_file = Path(index_file) deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`." diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 17e9d20431..fcd7775fb6 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -23,11 +23,11 @@ import re from collections import OrderedDict from functools import partial, wraps from pathlib import Path -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import safetensors import torch -from huggingface_hub import create_repo, split_torch_state_dict_into_shards +from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards from huggingface_hub.utils import validate_hf_hub_args from torch import Tensor, nn @@ -607,6 +607,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) quantization_config = kwargs.pop("quantization_config", None) + dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) disable_mmap = kwargs.pop("disable_mmap", False) allow_pickle = False @@ -700,6 +701,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): revision=revision, subfolder=subfolder, user_agent=user_agent, + dduf_entries=dduf_entries, **kwargs, ) # no in-place modification of the original config. @@ -776,13 +778,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): "revision": revision, "user_agent": user_agent, "commit_hash": commit_hash, + "dduf_entries": dduf_entries, } index_file = _fetch_index_file(**index_file_kwargs) # In case the index file was not found we still have to consider the legacy format. # this becomes applicable when the variant is not None. if variant is not None and (index_file is None or not os.path.exists(index_file)): index_file = _fetch_index_file_legacy(**index_file_kwargs) - if index_file is not None and index_file.is_file(): + if index_file is not None and (dduf_entries or index_file.is_file()): is_sharded = True if is_sharded and from_flax: @@ -811,6 +814,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): model = load_flax_checkpoint_in_pytorch_model(model, model_file) else: + # in the case it is sharded, we have already the index if is_sharded: sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( pretrained_model_name_or_path, @@ -822,10 +826,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): user_agent=user_agent, revision=revision, subfolder=subfolder or "", + dduf_entries=dduf_entries, ) # TODO: https://github.com/huggingface/diffusers/issues/10013 - if hf_quantizer is not None: - model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) + if hf_quantizer is not None or dduf_entries: + model_file = _merge_sharded_checkpoints( + sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries + ) logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") is_sharded = False @@ -843,6 +850,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, + dduf_entries=dduf_entries, ) except IOError as e: @@ -866,6 +874,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, + dduf_entries=dduf_entries, ) if low_cpu_mem_usage: @@ -887,7 +896,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): # TODO (sayakpaul, SunMarc): remove this after model loading refactor else: param_device = torch.device(torch.cuda.current_device()) - state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap) + state_dict = load_state_dict( + model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap + ) model._convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu @@ -983,7 +994,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): else: model = cls.from_config(config, **unused_kwargs) - state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap) + state_dict = load_state_dict( + model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap + ) model._convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 23f1279e20..a100dfe77b 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -12,19 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - import importlib import os import re import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union +import requests import torch -from huggingface_hub import ModelCard, model_info -from huggingface_hub.utils import validate_hf_hub_args +from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download +from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args from packaging import version +from requests.exceptions import HTTPError from .. import __version__ from ..utils import ( @@ -38,14 +38,16 @@ from ..utils import ( is_accelerate_available, is_peft_available, is_transformers_available, + is_transformers_version, logging, ) from ..utils.torch_utils import is_compiled_module +from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transformers_model_from_dduf if is_transformers_available(): import transformers - from transformers import PreTrainedModel + from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME @@ -627,6 +629,7 @@ def load_sub_model( low_cpu_mem_usage: bool, cached_folder: Union[str, os.PathLike], use_safetensors: bool, + dduf_entries: Optional[Dict[str, DDUFEntry]], ): """Helper method to load the module `name` from `library_name` and `class_name`""" @@ -663,7 +666,7 @@ def load_sub_model( f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." ) - load_method = getattr(class_obj, load_method_name) + load_method = _get_load_method(class_obj, load_method_name, is_dduf=dduf_entries is not None) # add kwargs to loading method diffusers_module = importlib.import_module(__name__.split(".")[0]) @@ -721,7 +724,10 @@ def load_sub_model( loading_kwargs["low_cpu_mem_usage"] = False # check if the module is in a subdirectory - if os.path.isdir(os.path.join(cached_folder, name)): + if dduf_entries: + loading_kwargs["dduf_entries"] = dduf_entries + loaded_sub_model = load_method(name, **loading_kwargs) + elif os.path.isdir(os.path.join(cached_folder, name)): loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) else: # else load from the root directory @@ -746,6 +752,22 @@ def load_sub_model( return loaded_sub_model +def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> Callable: + """ + Return the method to load the sub model. + + In practice, this method will return the `"from_pretrained"` (or `load_method_name`) method of the class object + except if loading from a DDUF checkpoint. In that case, transformers models and tokenizers have a specific loading + method that we need to use. + """ + if is_dduf: + if issubclass(class_obj, PreTrainedTokenizerBase): + return lambda *args, **kwargs: _load_tokenizer_from_dduf(class_obj, *args, **kwargs) + if issubclass(class_obj, PreTrainedModel): + return lambda *args, **kwargs: _load_transformers_model_from_dduf(class_obj, *args, **kwargs) + return getattr(class_obj, load_method_name) + + def _fetch_class_library_tuple(module): # import it here to avoid circular import diffusers_module = importlib.import_module(__name__.split(".")[0]) @@ -968,3 +990,70 @@ def _get_ignore_patterns( ) return ignore_patterns + + +def _download_dduf_file( + pretrained_model_name: str, + dduf_file: str, + pipeline_class_name: str, + cache_dir: str, + proxies: str, + local_files_only: bool, + token: str, + revision: str, +): + model_info_call_error = None + if not local_files_only: + try: + info = model_info(pretrained_model_name, token=token, revision=revision) + except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e: + logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.") + local_files_only = True + model_info_call_error = e # save error to reraise it if model is not cached locally + + if ( + not local_files_only + and dduf_file is not None + and dduf_file not in (sibling.rfilename for sibling in info.siblings) + ): + raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.") + + try: + user_agent = {"pipeline_class": pipeline_class_name, "dduf": True} + cached_folder = snapshot_download( + pretrained_model_name, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=[dduf_file], + user_agent=user_agent, + ) + return cached_folder + except FileNotFoundError: + # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache. + # This can happen in two cases: + # 1. If the user passed `local_files_only=True` => we raise the error directly + # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error + if model_info_call_error is None: + # 1. user passed `local_files_only=True` + raise + else: + # 2. we forced `local_files_only=True` when `model_info` failed + raise EnvironmentError( + f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred" + " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace" + " above." + ) from model_info_call_error + + +def _maybe_raise_error_for_incorrect_transformers(config_dict): + has_transformers_component = False + for k in config_dict: + if isinstance(config_dict[k], list): + has_transformers_component = config_dict[k][0] == "transformers" + if has_transformers_component: + break + if has_transformers_component and not is_transformers_version(">", "4.47.1"): + raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.") diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 527724d1de..3cafb77e5d 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -29,10 +29,12 @@ import PIL.Image import requests import torch from huggingface_hub import ( + DDUFEntry, ModelCard, create_repo, hf_hub_download, model_info, + read_dduf_file, snapshot_download, ) from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args @@ -72,6 +74,7 @@ from .pipeline_loading_utils import ( CONNECTED_PIPES_KEYS, CUSTOM_PIPELINE_FILE_NAME, LOADABLE_CLASSES, + _download_dduf_file, _fetch_class_library_tuple, _get_custom_components_and_folders, _get_custom_pipeline_class, @@ -79,6 +82,7 @@ from .pipeline_loading_utils import ( _get_ignore_patterns, _get_pipeline_class, _identify_model_variants, + _maybe_raise_error_for_incorrect_transformers, _maybe_raise_warning_for_inpainting, _resolve_custom_pipeline_and_cls, _unwrap_model, @@ -218,6 +222,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). + kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ @@ -531,6 +536,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights saved using [`~DiffusionPipeline.save_pretrained`]. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the dtype is automatically derived from the model's weights. @@ -625,6 +631,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): variant (`str`, *optional*): Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. + dduf_file(`str`, *optional*): + Load weights from the specified dduf file. @@ -674,6 +682,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): offload_state_dict = kwargs.pop("offload_state_dict", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) + dduf_file = kwargs.pop("dduf_file", None) use_safetensors = kwargs.pop("use_safetensors", None) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) @@ -722,6 +731,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): " dispatching. Please make sure to set `low_cpu_mem_usage=True`." ) + if dduf_file: + if custom_pipeline: + raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.") + if load_connected_pipeline: + raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.") + # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): @@ -744,6 +759,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): custom_pipeline=custom_pipeline, custom_revision=custom_revision, variant=variant, + dduf_file=dduf_file, load_connected_pipeline=load_connected_pipeline, **kwargs, ) @@ -765,7 +781,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ) logger.warning(warn_msg) - config_dict = cls.load_config(cached_folder) + dduf_entries = None + if dduf_file: + dduf_file_path = os.path.join(cached_folder, dduf_file) + dduf_entries = read_dduf_file(dduf_file_path) + # The reader contains already all the files needed, no need to check it again + cached_folder = "" + + config_dict = cls.load_config(cached_folder, dduf_entries=dduf_entries) + + if dduf_file: + _maybe_raise_error_for_incorrect_transformers(config_dict) # pop out "_ignore_files" as it is only needed for download config_dict.pop("_ignore_files", None) @@ -943,6 +969,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, use_safetensors=use_safetensors, + dduf_entries=dduf_entries, ) logger.info( f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." @@ -1256,6 +1283,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): variant (`str`, *optional*): Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. + dduf_file(`str`, *optional*): + Load weights from the specified DDUF file. use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors @@ -1296,6 +1325,23 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) trust_remote_code = kwargs.pop("trust_remote_code", False) + dduf_file: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_file", None) + + if dduf_file: + if custom_pipeline: + raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.") + if load_connected_pipeline: + raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.") + return _download_dduf_file( + pretrained_model_name=pretrained_model_name, + dduf_file=dduf_file, + pipeline_class_name=cls.__name__, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + ) allow_pickle = False if use_safetensors is None: @@ -1375,7 +1421,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] # also allow downloading config.json files with the model allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] - allow_patterns += [ SCHEDULER_CONFIG_NAME, CONFIG_NAME, @@ -1471,7 +1516,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): user_agent=user_agent, ) - # retrieve pipeline class from local file cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None) cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name diff --git a/src/diffusers/pipelines/transformers_loading_utils.py b/src/diffusers/pipelines/transformers_loading_utils.py new file mode 100644 index 0000000000..f080adb23d --- /dev/null +++ b/src/diffusers/pipelines/transformers_loading_utils.py @@ -0,0 +1,121 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +import os +import tempfile +from typing import TYPE_CHECKING, Dict + +from huggingface_hub import DDUFEntry +from tqdm import tqdm + +from ..utils import is_safetensors_available, is_transformers_available, is_transformers_version + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + +if is_transformers_available(): + from transformers import PreTrainedModel, PreTrainedTokenizer + +if is_safetensors_available(): + import safetensors.torch + + +def _load_tokenizer_from_dduf( + cls: "PreTrainedTokenizer", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs +) -> "PreTrainedTokenizer": + """ + Load a tokenizer from a DDUF archive. + + In practice, `transformers` do not provide a way to load a tokenizer from a DDUF archive. This function is a + workaround by extracting the tokenizer files from the DDUF archive and loading the tokenizer from the extracted + files. There is an extra cost of extracting the files, but of limited impact as the tokenizer files are usually + small-ish. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + for entry_name, entry in dduf_entries.items(): + if entry_name.startswith(name + "/"): + tmp_entry_path = os.path.join(tmp_dir, *entry_name.split("/")) + # need to create intermediary directory if they don't exist + os.makedirs(os.path.dirname(tmp_entry_path), exist_ok=True) + with open(tmp_entry_path, "wb") as f: + with entry.as_mmap() as mm: + f.write(mm) + return cls.from_pretrained(os.path.dirname(tmp_entry_path), **kwargs) + + +def _load_transformers_model_from_dduf( + cls: "PreTrainedModel", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs +) -> "PreTrainedModel": + """ + Load a transformers model from a DDUF archive. + + In practice, `transformers` do not provide a way to load a model from a DDUF archive. This function is a workaround + by instantiating a model from the config file and loading the weights from the DDUF archive directly. + """ + config_file = dduf_entries.get(f"{name}/config.json") + if config_file is None: + raise EnvironmentError( + f"Could not find a config.json file for component {name} in DDUF file (contains {dduf_entries.keys()})." + ) + generation_config = dduf_entries.get(f"{name}/generation_config.json", None) + + weight_files = [ + entry + for entry_name, entry in dduf_entries.items() + if entry_name.startswith(f"{name}/") and entry_name.endswith(".safetensors") + ] + if not weight_files: + raise EnvironmentError( + f"Could not find any weight file for component {name} in DDUF file (contains {dduf_entries.keys()})." + ) + if not is_safetensors_available(): + raise EnvironmentError( + "Safetensors is not available, cannot load model from DDUF. Please `pip install safetensors`." + ) + if is_transformers_version("<", "4.47.0"): + raise ImportError( + "You need to install `transformers>4.47.0` in order to load a transformers model from a DDUF file. " + "You can install it with: `pip install --upgrade transformers`" + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + from transformers import AutoConfig, GenerationConfig + + tmp_config_file = os.path.join(tmp_dir, "config.json") + with open(tmp_config_file, "w") as f: + f.write(config_file.read_text()) + config = AutoConfig.from_pretrained(tmp_config_file) + if generation_config is not None: + tmp_generation_config_file = os.path.join(tmp_dir, "generation_config.json") + with open(tmp_generation_config_file, "w") as f: + f.write(generation_config.read_text()) + generation_config = GenerationConfig.from_pretrained(tmp_generation_config_file) + state_dict = {} + with contextlib.ExitStack() as stack: + for entry in tqdm(weight_files, desc="Loading state_dict"): # Loop over safetensors files + # Memory-map the safetensors file + mmap = stack.enter_context(entry.as_mmap()) + # Load tensors from the memory-mapped file + tensors = safetensors.torch.load(mmap) + # Update the state dictionary with tensors + state_dict.update(tensors) + return cls.from_pretrained( + pretrained_model_name_or_path=None, + config=config, + generation_config=generation_config, + state_dict=state_dict, + **kwargs, + ) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index f8de48ecfc..5a171d078c 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -70,6 +70,7 @@ from .import_utils import ( is_gguf_available, is_gguf_version, is_google_colab, + is_hf_hub_version, is_inflect_available, is_invisible_watermark_available, is_k_diffusion_available, diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index a6dfe18433..839e696c0c 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -26,6 +26,7 @@ from typing import Dict, List, Optional, Union from uuid import uuid4 from huggingface_hub import ( + DDUFEntry, ModelCard, ModelCardData, create_repo, @@ -291,9 +292,26 @@ def _get_model_file( user_agent: Optional[Union[Dict, str]] = None, revision: Optional[str] = None, commit_hash: Optional[str] = None, + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, ): pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isfile(pretrained_model_name_or_path): + + if dduf_entries: + if subfolder is not None: + raise ValueError( + "DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). " + "Please check the DDUF structure" + ) + model_file = ( + weights_name + if pretrained_model_name_or_path == "" + else "/".join([pretrained_model_name_or_path, weights_name]) + ) + if model_file in dduf_entries: + return model_file + else: + raise EnvironmentError(f"Error no file named {weights_name} found in archive {dduf_entries.keys()}.") + elif os.path.isfile(pretrained_model_name_or_path): return pretrained_model_name_or_path elif os.path.isdir(pretrained_model_name_or_path): if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): @@ -419,6 +437,7 @@ def _get_checkpoint_shard_files( user_agent=None, revision=None, subfolder="", + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, ): """ For a given model: @@ -430,11 +449,18 @@ def _get_checkpoint_shard_files( For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). """ - if not os.path.isfile(index_filename): - raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") + if dduf_entries: + if index_filename not in dduf_entries: + raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") + else: + if not os.path.isfile(index_filename): + raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") - with open(index_filename, "r") as f: - index = json.loads(f.read()) + if dduf_entries: + index = json.loads(dduf_entries[index_filename].read_text()) + else: + with open(index_filename, "r") as f: + index = json.loads(f.read()) original_shard_filenames = sorted(set(index["weight_map"].values())) sharded_metadata = index["metadata"] @@ -448,6 +474,8 @@ def _get_checkpoint_shard_files( pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames ) return shards_path, sharded_metadata + elif dduf_entries: + return shards_path, sharded_metadata # At this stage pretrained_model_name_or_path is a model identifier on the Hub allow_patterns = original_shard_filenames diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 3014efebc8..c7d002651f 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -115,6 +115,13 @@ try: except importlib_metadata.PackageNotFoundError: _transformers_available = False +_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None +try: + _hf_hub_version = importlib_metadata.version("huggingface_hub") + logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}") +except importlib_metadata.PackageNotFoundError: + _hf_hub_available = False + _inflect_available = importlib.util.find_spec("inflect") is not None try: @@ -767,6 +774,21 @@ def is_transformers_version(operation: str, version: str): return compare_versions(parse(_transformers_version), operation, version) +def is_hf_hub_version(operation: str, version: str): + """ + Compares the current Hugging Face Hub version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _hf_hub_available: + return False + return compare_versions(parse(_hf_hub_version), operation, version) + + def is_accelerate_version(operation: str, version: str): """ Compares the current Accelerate version to a given reference with an operation. diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 3ae74cddcb..62156786c6 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -478,6 +478,18 @@ def require_bitsandbytes_version_greater(bnb_version): return decorator +def require_hf_hub_version_greater(hf_hub_version): + def decorator(test_case): + correct_hf_hub_version = version.parse( + version.parse(importlib.metadata.version("huggingface_hub")).base_version + ) > version.parse(hf_hub_version) + return unittest.skipUnless( + correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}." + )(test_case) + + return decorator + + def require_gguf_version_greater_or_equal(gguf_version): def decorator(test_case): correct_gguf_version = is_gguf_available() and version.parse( diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py index d09fc04883..6ca96b19b8 100644 --- a/tests/pipelines/allegro/test_allegro.py +++ b/tests/pipelines/allegro/test_allegro.py @@ -14,6 +14,8 @@ import gc import inspect +import os +import tempfile import unittest import numpy as np @@ -24,7 +26,9 @@ from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLA from diffusers.utils.testing_utils import ( enable_full_determinism, numpy_cosine_similarity_distance, + require_hf_hub_version_greater, require_torch_gpu, + require_transformers_version_greater, slow, torch_device, ) @@ -297,6 +301,35 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "VAE tiling should not affect the inference results", ) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_save_load_dduf(self): + # reimplement because it needs `enable_tiling()` on the loaded pipe. + from huggingface_hub import export_folder_as_dduf + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device="cpu") + inputs.pop("generator") + inputs["generator"] = torch.manual_seed(0) + + pipeline_out = pipe(**inputs)[0].cpu() + + with tempfile.TemporaryDirectory() as tmpdir: + dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf") + pipe.save_pretrained(tmpdir, safe_serialization=True) + export_folder_as_dduf(dduf_filename, folder_path=tmpdir) + loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device) + + loaded_pipe.vae.enable_tiling() + inputs["generator"] = torch.manual_seed(0) + loaded_pipeline_out = loaded_pipe(**inputs)[0].cpu() + + assert np.allclose(pipeline_out, loaded_pipeline_out) + @slow @require_torch_gpu diff --git a/tests/pipelines/audioldm/test_audioldm.py b/tests/pipelines/audioldm/test_audioldm.py index eddab54a3c..aaf44985aa 100644 --- a/tests/pipelines/audioldm/test_audioldm.py +++ b/tests/pipelines/audioldm/test_audioldm.py @@ -63,6 +63,8 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py index bf3ce2542d..95aaa370ef 100644 --- a/tests/pipelines/audioldm2/test_audioldm2.py +++ b/tests/pipelines/audioldm2/test_audioldm2.py @@ -70,6 +70,8 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = AudioLDM2UNet2DConditionModel( diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py index 7e85cef651..6d422745ce 100644 --- a/tests/pipelines/blipdiffusion/test_blipdiffusion.py +++ b/tests/pipelines/blipdiffusion/test_blipdiffusion.py @@ -60,6 +60,8 @@ class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "prompt_reps", ] + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) text_encoder_config = CLIPTextConfig( diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index b12655d989..fc8ea5284c 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -291,6 +291,8 @@ class StableDiffusionMultiControlNetPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( @@ -523,6 +525,8 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py index 99a238caf5..b4d3e3aaa8 100644 --- a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py +++ b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py @@ -68,6 +68,8 @@ class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.Tes "prompt_reps", ] + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) text_encoder_config = CLIPTextConfig( diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 7c4ae716b3..516fcc513b 100644 --- a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -198,6 +198,8 @@ class StableDiffusionMultiControlNetPipelineFastTests( batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index e49106334c..0e4dba4265 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -257,6 +257,8 @@ class MultiControlNetInpaintPipelineFastTests( params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py index d2c63137c9..6e752804e2 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py @@ -78,6 +78,8 @@ class ControlNetPipelineSDXLFastTests( } ) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index ea7fff5537..fc15973fae 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -487,6 +487,8 @@ class StableDiffusionXLMultiControlNetPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( @@ -692,6 +694,8 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/deepfloyd_if/test_if.py b/tests/pipelines/deepfloyd_if/test_if.py index 13a05855f1..2231821fbc 100644 --- a/tests/pipelines/deepfloyd_if/test_if.py +++ b/tests/pipelines/deepfloyd_if/test_if.py @@ -26,7 +26,9 @@ from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( load_numpy, require_accelerator, + require_hf_hub_version_greater, require_torch_gpu, + require_transformers_version_greater, skip_mps, slow, torch_device, @@ -89,6 +91,11 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.T def test_xformers_attention_forwardGenerator_pass(self): self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-2, rtol=1e-2) + @slow @require_torch_gpu diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img.py b/tests/pipelines/deepfloyd_if/test_if_img2img.py index 26ac42831b..c6d5384e24 100644 --- a/tests/pipelines/deepfloyd_if/test_if_img2img.py +++ b/tests/pipelines/deepfloyd_if/test_if_img2img.py @@ -26,7 +26,9 @@ from diffusers.utils.testing_utils import ( floats_tensor, load_numpy, require_accelerator, + require_hf_hub_version_greater, require_torch_gpu, + require_transformers_version_greater, skip_mps, slow, torch_device, @@ -100,6 +102,11 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni expected_max_diff=1e-2, ) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-2, rtol=1e-2) + @slow @require_torch_gpu diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py index 1d1244c96c..7cdd8cd147 100644 --- a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py +++ b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py @@ -26,7 +26,9 @@ from diffusers.utils.testing_utils import ( floats_tensor, load_numpy, require_accelerator, + require_hf_hub_version_greater, require_torch_gpu, + require_transformers_version_greater, skip_mps, slow, torch_device, @@ -97,6 +99,11 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT expected_max_diff=1e-2, ) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-2, rtol=1e-2) + @slow @require_torch_gpu diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting.py b/tests/pipelines/deepfloyd_if/test_if_inpainting.py index 1c4f274033..9f15119025 100644 --- a/tests/pipelines/deepfloyd_if/test_if_inpainting.py +++ b/tests/pipelines/deepfloyd_if/test_if_inpainting.py @@ -26,7 +26,9 @@ from diffusers.utils.testing_utils import ( floats_tensor, load_numpy, require_accelerator, + require_hf_hub_version_greater, require_torch_gpu, + require_transformers_version_greater, skip_mps, slow, torch_device, @@ -97,6 +99,11 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, expected_max_diff=1e-2, ) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-2, rtol=1e-2) + @slow @require_torch_gpu diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py index fc1b04aacb..c2b48bfd6d 100644 --- a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py +++ b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py @@ -26,7 +26,9 @@ from diffusers.utils.testing_utils import ( floats_tensor, load_numpy, require_accelerator, + require_hf_hub_version_greater, require_torch_gpu, + require_transformers_version_greater, skip_mps, slow, torch_device, @@ -99,6 +101,11 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli expected_max_diff=1e-2, ) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-2, rtol=1e-2) + @slow @require_torch_gpu diff --git a/tests/pipelines/deepfloyd_if/test_if_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_superresolution.py index bdb9f8a76d..57e12899e4 100644 --- a/tests/pipelines/deepfloyd_if/test_if_superresolution.py +++ b/tests/pipelines/deepfloyd_if/test_if_superresolution.py @@ -26,7 +26,9 @@ from diffusers.utils.testing_utils import ( floats_tensor, load_numpy, require_accelerator, + require_hf_hub_version_greater, require_torch_gpu, + require_transformers_version_greater, skip_mps, slow, torch_device, @@ -92,6 +94,11 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi expected_max_diff=1e-2, ) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_save_load_dduf(self): + super().test_save_load_dduf(atol=1e-2, rtol=1e-2) + @slow @require_torch_gpu diff --git a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py index 592ebd35f4..f4d6165f90 100644 --- a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py +++ b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py @@ -59,6 +59,8 @@ class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unit # No `output_type`. required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"]) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) scheduler = DDIMScheduler( diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py index 8553ed96e9..1a13ec75d0 100644 --- a/tests/pipelines/kandinsky/test_kandinsky.py +++ b/tests/pipelines/kandinsky/test_kandinsky.py @@ -204,6 +204,8 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummy = Dummies() return dummy.get_dummy_components() diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py index a7f861565c..3c8767a708 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py @@ -52,6 +52,8 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase) ] test_xformers_attention = True + supports_dduf = False + def get_dummy_components(self): dummy = Dummies() prior_dummy = PriorDummies() @@ -160,6 +162,8 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummy = Img2ImgDummies() prior_dummy = PriorDummies() @@ -269,6 +273,8 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummy = InpaintDummies() prior_dummy = PriorDummies() diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py index ea289c5ccd..23f13ffee2 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py +++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py @@ -226,6 +226,8 @@ class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummies = Dummies() return dummies.get_dummy_components() diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py index 7400466787..ebb1a4d887 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py @@ -220,6 +220,8 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummies = Dummies() return dummies.get_dummy_components() diff --git a/tests/pipelines/kandinsky/test_kandinsky_prior.py b/tests/pipelines/kandinsky/test_kandinsky_prior.py index 5f42447bd9..abb53bfb79 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_prior.py +++ b/tests/pipelines/kandinsky/test_kandinsky_prior.py @@ -184,6 +184,8 @@ class KandinskyPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummy = Dummies() return dummy.get_dummy_components() diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py index dbba083139..bbf2f08a7b 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py @@ -57,6 +57,8 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa test_xformers_attention = True callback_cfg_params = ["image_embds"] + supports_dduf = False + def get_dummy_components(self): dummy = Dummies() prior_dummy = PriorDummies() @@ -181,6 +183,8 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest test_xformers_attention = False callback_cfg_params = ["image_embds"] + supports_dduf = False + def get_dummy_components(self): dummy = Img2ImgDummies() prior_dummy = PriorDummies() @@ -302,6 +306,8 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummy = InpaintDummies() prior_dummy = PriorDummies() diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py index be0bc238d4..bdec6c132f 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py @@ -186,6 +186,8 @@ class KandinskyV22PriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase) callback_cfg_params = ["prompt_embeds", "text_encoder_hidden_states", "text_mask"] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummies = Dummies() return dummies.get_dummy_components() diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py index e898824e2d..0ea32981d5 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py @@ -59,6 +59,8 @@ class KandinskyV22PriorEmb2EmbPipelineFastTests(PipelineTesterMixin, unittest.Te ] test_xformers_attention = False + supports_dduf = False + @property def text_embedder_hidden_size(self): return 32 diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py index de44af6d59..e88ba02820 100644 --- a/tests/pipelines/kolors/test_kolors.py +++ b/tests/pipelines/kolors/test_kolors.py @@ -47,6 +47,8 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) + supports_dduf = False + def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/kolors/test_kolors_img2img.py b/tests/pipelines/kolors/test_kolors_img2img.py index 2010dbd705..9f1ca43a08 100644 --- a/tests/pipelines/kolors/test_kolors_img2img.py +++ b/tests/pipelines/kolors/test_kolors_img2img.py @@ -51,6 +51,8 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) + supports_dduf = False + # Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py index 5fd0dbf060..e0fd06847b 100644 --- a/tests/pipelines/lumina/test_lumina_nextdit.py +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -31,6 +31,8 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM ) batch_params = frozenset(["prompt", "negative_prompt"]) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) transformer = LuminaNextDiT2DModel( diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py index e51f510393..bdd536b6ff 100644 --- a/tests/pipelines/musicldm/test_musicldm.py +++ b/tests/pipelines/musicldm/test_musicldm.py @@ -65,6 +65,8 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py index 8cfb2c3fd1..cf9466988d 100644 --- a/tests/pipelines/pag/test_pag_kolors.py +++ b/tests/pipelines/pag/test_pag_kolors.py @@ -56,6 +56,8 @@ class KolorsPAGPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) + supports_dduf = False + # Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/pag/test_pag_sana.py b/tests/pipelines/pag/test_pag_sana.py index 12addabeb0..a2c6572978 100644 --- a/tests/pipelines/pag/test_pag_sana.py +++ b/tests/pipelines/pag/test_pag_sana.py @@ -53,6 +53,8 @@ class SanaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) transformer = SanaTransformer2DModel( diff --git a/tests/pipelines/pag/test_pag_sdxl_img2img.py b/tests/pipelines/pag/test_pag_sdxl_img2img.py index 7e5fc5fa28..33bd47bfee 100644 --- a/tests/pipelines/pag/test_pag_sdxl_img2img.py +++ b/tests/pipelines/pag/test_pag_sdxl_img2img.py @@ -82,6 +82,8 @@ class StableDiffusionXLPAGImg2ImgPipelineFastTests( {"add_text_embeds", "add_time_ids", "add_neg_time_ids"} ) + supports_dduf = False + # based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_img2img_pipeline.get_dummy_components def get_dummy_components( self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py index efc37abd06..8378b07e9f 100644 --- a/tests/pipelines/pag/test_pag_sdxl_inpaint.py +++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py @@ -82,6 +82,8 @@ class StableDiffusionXLPAGInpaintPipelineFastTests( {"add_text_embeds", "add_time_ids", "mask", "masked_image_latents"} ) + supports_dduf = False + # based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipelineFastTests.get_dummy_components def get_dummy_components( self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py index c71e2d4761..6b668de276 100644 --- a/tests/pipelines/paint_by_example/test_paint_by_example.py +++ b/tests/pipelines/paint_by_example/test_paint_by_example.py @@ -46,6 +46,8 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase): batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS image_params = frozenset([]) # TO_DO: update the image_prams once refactored VaeImageProcessor.preprocess + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py index f3661355e9..ac7096874b 100644 --- a/tests/pipelines/shap_e/test_shap_e_img2img.py +++ b/tests/pipelines/shap_e/test_shap_e_img2img.py @@ -50,6 +50,8 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] test_xformers_attention = False + supports_dduf = False + @property def text_embedder_hidden_size(self): return 16 diff --git a/tests/pipelines/stable_audio/test_stable_audio.py b/tests/pipelines/stable_audio/test_stable_audio.py index 41ac94891c..b2ca3ddd0e 100644 --- a/tests/pipelines/stable_audio/test_stable_audio.py +++ b/tests/pipelines/stable_audio/test_stable_audio.py @@ -70,6 +70,7 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) # There is not xformers version of the StableAudioPipeline custom attention processor test_xformers_attention = False + supports_dduf = False def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 01a0a3abe4..430d99781a 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -76,6 +76,8 @@ class StableDiffusionDepth2ImgPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"depth_mask"}) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py index 2a1e691e9e..15f298c67e 100644 --- a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py +++ b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py @@ -389,6 +389,8 @@ class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterM class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase): + supports_dduf = False + def get_dummy_components(self, time_cond_proj_dim=None): return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim) diff --git a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py index 748702541b..15e4c60db8 100644 --- a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py +++ b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py @@ -66,6 +66,8 @@ class GligenTextImagePipelineFastTests( image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py index 7a3b0f70cc..d7567afdee 100644 --- a/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py @@ -58,6 +58,8 @@ class StableDiffusionImageVariationPipelineFastTests( # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess image_latents_params = frozenset([]) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index 7c7b037865..23291b0407 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -422,6 +422,8 @@ class StableDiffusionXLAdapterPipelineFastTests( class StableDiffusionXLMultiAdapterPipelineFastTests( StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase ): + supports_dduf = False + def get_dummy_components(self, time_cond_proj_dim=None): return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index db0905a483..ceec86a811 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -77,6 +77,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests( {"add_text_embeds", "add_time_ids", "add_neg_time_ids"} ) + supports_dduf = False + def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py index 964c7123dd..c759f4c112 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py @@ -72,6 +72,8 @@ class StableDiffusionXLInpaintPipelineFastTests( } ) + supports_dduf = False + def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index a5cbf77615..34f2553a91 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -51,6 +51,8 @@ class StableUnCLIPImg2ImgPipelineFastTests( ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess image_latents_params = frozenset([]) + supports_dduf = False + def get_dummy_components(self): embedder_hidden_size = 32 embedder_projection_dim = embedder_hidden_size diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py index ac9acb26af..352477ecec 100644 --- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py +++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py @@ -58,6 +58,8 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa ] ) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNetSpatioTemporalConditionModel( diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 423c82e060..6665a005ba 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -75,9 +75,11 @@ from diffusers.utils.testing_utils import ( nightly, require_compel, require_flax, + require_hf_hub_version_greater, require_onnxruntime, require_torch_2, require_torch_gpu, + require_transformers_version_greater, run_test_in_subprocess, slow, torch_device, @@ -981,6 +983,18 @@ class DownloadTests(unittest.TestCase): assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files) assert len(files) == 14 + def test_download_dduf_with_custom_pipeline_raises_error(self): + with self.assertRaises(NotImplementedError): + _ = DiffusionPipeline.download( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline" + ) + + def test_download_dduf_with_connected_pipeline_raises_error(self): + with self.assertRaises(NotImplementedError): + _ = DiffusionPipeline.download( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True + ) + def test_get_pipeline_class_from_flax(self): flax_config = {"_class_name": "FlaxStableDiffusionPipeline"} config = {"_class_name": "StableDiffusionPipeline"} @@ -1802,6 +1816,55 @@ class PipelineFastTests(unittest.TestCase): sd.maybe_free_model_hooks() assert sd._offload_gpu_id == 5 + @parameterized.expand([torch.float32, torch.float16]) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_load_dduf_from_hub(self, dtype): + with tempfile.TemporaryDirectory() as tmpdir: + pipe = DiffusionPipeline.from_pretrained( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, torch_dtype=dtype + ).to(torch_device) + out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images + + pipe.save_pretrained(tmpdir) + loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=dtype).to(torch_device) + + out_2 = loaded_pipe( + prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np" + ).images + + self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4)) + + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_load_dduf_from_hub_local_files_only(self): + with tempfile.TemporaryDirectory() as tmpdir: + pipe = DiffusionPipeline.from_pretrained( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir + ).to(torch_device) + out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images + + local_files_pipe = DiffusionPipeline.from_pretrained( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, local_files_only=True + ).to(torch_device) + out_2 = local_files_pipe( + prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np" + ).images + + self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4)) + + def test_dduf_raises_error_with_custom_pipeline(self): + with self.assertRaises(NotImplementedError): + _ = DiffusionPipeline.from_pretrained( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline" + ) + + def test_dduf_raises_error_with_connected_pipeline(self): + with self.assertRaises(NotImplementedError): + _ = DiffusionPipeline.from_pretrained( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True + ) + def test_wrong_model(self): tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") with self.assertRaises(ValueError) as error_context: @@ -1812,6 +1875,27 @@ class PipelineFastTests(unittest.TestCase): assert "is of type" in str(error_context.exception) assert "but should be" in str(error_context.exception) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_dduf_load_sharded_checkpoint_diffusion_model(self): + with tempfile.TemporaryDirectory() as tmpdir: + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-flux-dev-pipe-sharded-checkpoint-DDUF", + dduf_file="tiny-flux-dev-pipe-sharded-checkpoint.dduf", + cache_dir=tmpdir, + ).to(torch_device) + + out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images + + pipe.save_pretrained(tmpdir) + loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir).to(torch_device) + + out_2 = loaded_pipe( + prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np" + ).images + + self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4)) + @slow @require_torch_gpu diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f5494fbade..83b628e09f 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -43,7 +43,9 @@ from diffusers.utils.testing_utils import ( CaptureLogger, require_accelerate_version_greater, require_accelerator, + require_hf_hub_version_greater, require_torch, + require_transformers_version_greater, skip_mps, torch_device, ) @@ -986,6 +988,8 @@ class PipelineTesterMixin: test_xformers_attention = True + supports_dduf = True + def get_generator(self, seed): device = torch_device if torch_device != "mps" else "cpu" generator = torch.Generator(device).manual_seed(seed) @@ -1990,6 +1994,39 @@ class PipelineTesterMixin: ) ) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_save_load_dduf(self, atol=1e-4, rtol=1e-4): + if not self.supports_dduf: + return + + from huggingface_hub import export_folder_as_dduf + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device="cpu") + inputs.pop("generator") + inputs["generator"] = torch.manual_seed(0) + + pipeline_out = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf") + pipe.save_pretrained(tmpdir, safe_serialization=True) + export_folder_as_dduf(dduf_filename, folder_path=tmpdir) + loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device) + + inputs["generator"] = torch.manual_seed(0) + loaded_pipeline_out = loaded_pipe(**inputs)[0] + + if isinstance(pipeline_out, np.ndarray) and isinstance(loaded_pipeline_out, np.ndarray): + assert np.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol) + elif isinstance(pipeline_out, torch.Tensor) and isinstance(loaded_pipeline_out, torch.Tensor): + assert torch.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol) + @is_staging_test class PipelinePushToHubTester(unittest.TestCase): diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py index dfc3acc0c0..23a6cd6663 100644 --- a/tests/pipelines/unclip/test_unclip_image_variation.py +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -66,6 +66,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa "super_res_num_inference_steps", ] test_xformers_attention = False + supports_dduf = False @property def text_embedder_hidden_size(self): diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py index 2e0ba1cfb8..310e46a2e8 100644 --- a/tests/pipelines/unidiffuser/test_unidiffuser.py +++ b/tests/pipelines/unidiffuser/test_unidiffuser.py @@ -86,6 +86,8 @@ class UniDiffuserPipelineFastTests( # vae_latents, not latents, is the argument that corresponds to VAE latent inputs image_latents_params = frozenset(["vae_latents"]) + supports_dduf = False + def get_dummy_components(self): unet = UniDiffuserModel.from_pretrained( "hf-internal-testing/unidiffuser-diffusers-test",