Mirror of https://github.com/huggingface/diffusers.git
[FEAT] DDUF format (#10037)
* load and save dduf archive
* style
* switch to zip uncompressed
* updates
* Update src/diffusers/pipelines/pipeline_utils.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update src/diffusers/pipelines/pipeline_utils.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* first draft
* remove print
* switch to dduf_file for consistency
* switch to huggingface hub api
* fix log
* add a basic test
* Update src/diffusers/configuration_utils.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update src/diffusers/pipelines/pipeline_utils.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* Update src/diffusers/pipelines/pipeline_utils.py (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* fix
* fix variant
* change saving logic
* DDUF - Load transformers components manually (#10171)
  * update hfh version
  * Load transformers components manually
  * load encoder from_pretrained with state_dict
  * working version with transformers and tokenizer!
  * add generation_config case
  * fix tests
  * remove saving for now
  * typing
  * need next version from transformers
* Update src/diffusers/configuration_utils.py (Co-authored-by: Lucain <lucain@huggingface.co>)
* check path correctly
* Apply suggestions from code review (Co-authored-by: Lucain <lucain@huggingface.co>)
* update
* typing
* remove check for subfolder
* quality
* revert setup changes
* oups
* more readable condition
* add loading from the hub test
* add basic docs
* Apply suggestions from code review (Co-authored-by: Lucain <lucain@huggingface.co>)
* add example
* add
* make functions private
* Apply suggestions from code review (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* minor
* fixes
* fix
* change the precedence of parameterized
* error out when custom pipeline is passed with dduf_file
* updates
* fix
* updates
* fixes
* updates
* fix xfail condition
* fix xfail
* fixes
* sharded checkpoint compat
* add test for sharded checkpoint
* add suggestions
* Update src/diffusers/models/model_loading_utils.py (Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* from suggestions
* add class attributes to flag dduf tests
* last one
* fix logic
* remove comment
* revert changes

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Lucain <lucain@huggingface.co>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
docs/source/en/using-diffusers/other-formats.md

@@ -240,6 +240,46 @@ Benefits of using a single-file layout include:

1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
2. Easier to manage (download and share) a single file.

### DDUF

> [!WARNING]
> DDUF is an experimental file format and APIs related to it can change in the future.

DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers' multi-folder format and the widely popular single-file format.

Learn more about DDUF in the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf).

Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`].

```py
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
    "photo of a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
).images[0]
image.save("cat.png")
```

To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations.

```py
from huggingface_hub import export_folder_as_dduf
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

save_folder = "flux-dev"
pipe.save_pretrained(save_folder)
export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)
```

> [!TIP]
> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure.
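Because a DDUF archive is uncompressed ZIP, you can also inspect its contents without instantiating a pipeline. A minimal sketch using the [`~huggingface_hub.read_dduf_file`] helper (the same reader Diffusers uses internally); the local archive path is an assumption:

```py
from huggingface_hub import read_dduf_file

# Map of entry name -> DDUFEntry describing where each file lives in the archive
entries = read_dduf_file("FLUX.1-dev.dduf")
for name, entry in entries.items():
    print(name, entry.length)

# Text entries such as the pipeline's model_index.json can be read in place
print(entries["model_index.json"].read_text())
```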
## Convert layout and files

Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.
setup.py

```diff
@@ -101,7 +101,7 @@ _deps = [
     "filelock",
     "flax>=0.4.1",
     "hf-doc-builder>=0.3.0",
-    "huggingface-hub>=0.23.2",
+    "huggingface-hub>=0.27.0",
     "requests-mock==1.10.0",
     "importlib_metadata",
     "invisible-watermark>=0.2.0",
```
src/diffusers/configuration_utils.py

```diff
@@ -24,10 +24,10 @@ import os
 import re
 from collections import OrderedDict
 from pathlib import Path
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
-from huggingface_hub import create_repo, hf_hub_download
+from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
 from huggingface_hub.utils import (
     EntryNotFoundError,
     RepositoryNotFoundError,
@@ -347,6 +347,7 @@ class ConfigMixin:
         _ = kwargs.pop("mirror", None)
         subfolder = kwargs.pop("subfolder", None)
         user_agent = kwargs.pop("user_agent", {})
+        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
 
         user_agent = {**user_agent, "file_type": "config"}
         user_agent = http_user_agent(user_agent)
@@ -358,8 +359,15 @@ class ConfigMixin:
                 "`self.config_name` is not defined. Note that one should not load a config from "
                 "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
             )
 
-        if os.path.isfile(pretrained_model_name_or_path):
+        # Custom path for now
+        if dduf_entries:
+            if subfolder is not None:
+                raise ValueError(
+                    "DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
+                    "Please check the DDUF structure"
+                )
+            config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries)
+        elif os.path.isfile(pretrained_model_name_or_path):
             config_file = pretrained_model_name_or_path
         elif os.path.isdir(pretrained_model_name_or_path):
             if subfolder is not None and os.path.isfile(
@@ -426,10 +434,8 @@ class ConfigMixin:
                 f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                 f"containing a {cls.config_name} file"
             )
 
         try:
-            # Load config dict
-            config_dict = cls._dict_from_json_file(config_file)
+            config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries)
 
             commit_hash = extract_commit_hash(config_file)
         except (json.JSONDecodeError, UnicodeDecodeError):
@@ -552,9 +558,14 @@ class ConfigMixin:
         return init_dict, unused_kwargs, hidden_config_dict
 
     @classmethod
-    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
-        with open(json_file, "r", encoding="utf-8") as reader:
-            text = reader.read()
+    def _dict_from_json_file(
+        cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
+    ):
+        if dduf_entries:
+            text = dduf_entries[json_file].read_text()
+        else:
+            with open(json_file, "r", encoding="utf-8") as reader:
+                text = reader.read()
         return json.loads(text)
 
     def __repr__(self):
@@ -616,6 +627,20 @@ class ConfigMixin:
         with open(json_file_path, "w", encoding="utf-8") as writer:
             writer.write(self.to_json_string())
 
+    @classmethod
+    def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
+        # paths inside a DDUF file must always be "/"
+        config_file = (
+            cls.config_name
+            if pretrained_model_name_or_path == ""
+            else "/".join([pretrained_model_name_or_path, cls.config_name])
+        )
+        if config_file not in dduf_entries:
+            raise ValueError(
+                f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
+            )
+        return config_file
+
 
 def register_to_config(init):
     r"""
```
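For illustration, the DDUF branch of `_dict_from_json_file` above boils down to reading a JSON entry straight out of the archive. A hedged sketch, assuming a local archive that contains a `transformer/config.json` entry:

```py
import json

from huggingface_hub import read_dduf_file

entries = read_dduf_file("FLUX.1-dev.dduf")
# Entry names always use "/" as separator, at most one directory level deep
config = json.loads(entries["transformer/config.json"].read_text())
print(config.get("_class_name"))
```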
src/diffusers/dependency_versions_table.py

```diff
@@ -9,7 +9,7 @@ deps = {
     "filelock": "filelock",
     "flax": "flax>=0.4.1",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.23.2",
+    "huggingface-hub": "huggingface-hub>=0.27.0",
     "requests-mock": "requests-mock==1.10.0",
     "importlib_metadata": "importlib_metadata",
     "invisible-watermark": "invisible-watermark>=0.2.0",
```
src/diffusers/models/model_loading_utils.py

```diff
@@ -20,10 +20,11 @@ import os
 from array import array
 from collections import OrderedDict
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 import safetensors
 import torch
+from huggingface_hub import DDUFEntry
 from huggingface_hub.utils import EntryNotFoundError
 
 from ..utils import (
@@ -132,7 +133,10 @@ def _fetch_remapped_cls_from_config(config, old_class):
 
 
 def load_state_dict(
-    checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False
+    checkpoint_file: Union[str, os.PathLike],
+    variant: Optional[str] = None,
+    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+    disable_mmap: bool = False,
 ):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
@@ -144,6 +148,10 @@ def load_state_dict(
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
         if file_extension == SAFETENSORS_FILE_EXTENSION:
+            if dduf_entries:
+                # tensors are loaded on cpu
+                with dduf_entries[checkpoint_file].as_mmap() as mm:
+                    return safetensors.torch.load(mm)
             if disable_mmap:
                 return safetensors.torch.load(open(checkpoint_file, "rb").read())
             else:
@@ -284,6 +292,7 @@ def _fetch_index_file(
     revision,
     user_agent,
     commit_hash,
+    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
 ):
     if is_local:
         index_file = Path(
@@ -309,8 +318,10 @@ def _fetch_index_file(
                 subfolder=None,
                 user_agent=user_agent,
                 commit_hash=commit_hash,
+                dduf_entries=dduf_entries,
             )
-            index_file = Path(index_file)
+            if not dduf_entries:
+                index_file = Path(index_file)
         except (EntryNotFoundError, EnvironmentError):
             index_file = None
 
@@ -319,7 +330,9 @@ def _fetch_index_file(
 
 # Adapted from
 # https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
-def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
+def _merge_sharded_checkpoints(
+    sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None
+):
     weight_map = sharded_metadata.get("weight_map", None)
     if weight_map is None:
         raise KeyError("'weight_map' key not found in the shard index file.")
@@ -332,14 +345,23 @@ def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
     # Load tensors from each unique file
     for file_name in files_to_load:
         part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
-        if not os.path.exists(part_file_path):
-            raise FileNotFoundError(f"Part file {file_name} not found.")
+        if dduf_entries:
+            if part_file_path not in dduf_entries:
+                raise FileNotFoundError(f"Part file {file_name} not found.")
+        else:
+            if not os.path.exists(part_file_path):
+                raise FileNotFoundError(f"Part file {file_name} not found.")
 
         if is_safetensors:
-            with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
-                for tensor_key in f.keys():
-                    if tensor_key in weight_map:
-                        merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
+            if dduf_entries:
+                with dduf_entries[part_file_path].as_mmap() as mm:
+                    tensors = safetensors.torch.load(mm)
+                    merged_state_dict.update(tensors)
+            else:
+                with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
+                    for tensor_key in f.keys():
+                        if tensor_key in weight_map:
+                            merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
         else:
             merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
 
@@ -360,6 +382,7 @@ def _fetch_index_file_legacy(
     revision,
     user_agent,
     commit_hash,
+    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
 ):
     if is_local:
         index_file = Path(
@@ -400,6 +423,7 @@ def _fetch_index_file_legacy(
                 subfolder=None,
                 user_agent=user_agent,
                 commit_hash=commit_hash,
+                dduf_entries=dduf_entries,
             )
             index_file = Path(index_file)
             deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
```
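The `as_mmap` path in `load_state_dict` above deserializes safetensors bytes directly from the memory-mapped archive, so weights never go through a temporary extraction. A standalone sketch of the same pattern; the entry name is illustrative:

```py
import safetensors.torch
from huggingface_hub import read_dduf_file

entries = read_dduf_file("FLUX.1-dev.dduf")
# Memory-map the entry and load; tensors are materialized on CPU
with entries["vae/diffusion_pytorch_model.safetensors"].as_mmap() as mm:
    state_dict = safetensors.torch.load(mm)
```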
src/diffusers/models/modeling_utils.py

```diff
@@ -23,11 +23,11 @@ import re
 from collections import OrderedDict
 from functools import partial, wraps
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import safetensors
 import torch
-from huggingface_hub import create_repo, split_torch_state_dict_into_shards
+from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
@@ -607,6 +607,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
         allow_pickle = False
@@ -700,6 +701,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             revision=revision,
             subfolder=subfolder,
             user_agent=user_agent,
+            dduf_entries=dduf_entries,
             **kwargs,
         )
         # no in-place modification of the original config.
@@ -776,13 +778,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 "revision": revision,
                 "user_agent": user_agent,
                 "commit_hash": commit_hash,
+                "dduf_entries": dduf_entries,
             }
             index_file = _fetch_index_file(**index_file_kwargs)
             # In case the index file was not found we still have to consider the legacy format.
             # this becomes applicable when the variant is not None.
             if variant is not None and (index_file is None or not os.path.exists(index_file)):
                 index_file = _fetch_index_file_legacy(**index_file_kwargs)
-            if index_file is not None and index_file.is_file():
+            if index_file is not None and (dduf_entries or index_file.is_file()):
                 is_sharded = True
 
         if is_sharded and from_flax:
@@ -811,6 +814,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
+            # in the case it is sharded, we have already the index
            if is_sharded:
                 sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
                     pretrained_model_name_or_path,
@@ -822,10 +826,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     user_agent=user_agent,
                     revision=revision,
                     subfolder=subfolder or "",
+                    dduf_entries=dduf_entries,
                 )
                 # TODO: https://github.com/huggingface/diffusers/issues/10013
-                if hf_quantizer is not None:
-                    model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
+                if hf_quantizer is not None or dduf_entries:
+                    model_file = _merge_sharded_checkpoints(
+                        sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries
+                    )
                     logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                     is_sharded = False
 
@@ -843,6 +850,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     subfolder=subfolder,
                     user_agent=user_agent,
                     commit_hash=commit_hash,
+                    dduf_entries=dduf_entries,
                 )
 
         except IOError as e:
@@ -866,6 +874,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 subfolder=subfolder,
                 user_agent=user_agent,
                 commit_hash=commit_hash,
+                dduf_entries=dduf_entries,
             )
 
         if low_cpu_mem_usage:
@@ -887,7 +896,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
                 else:
                     param_device = torch.device(torch.cuda.current_device())
-                state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
+                state_dict = load_state_dict(
+                    model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
+                )
                 model._convert_deprecated_attention_blocks(state_dict)
 
                 # move the params from meta device to cpu
@@ -983,7 +994,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             else:
                 model = cls.from_config(config, **unused_kwargs)
 
-            state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap)
+            state_dict = load_state_dict(
+                model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap
+            )
             model._convert_deprecated_attention_blocks(state_dict)
 
             model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
```
src/diffusers/pipelines/pipeline_loading_utils.py

```diff
@@ -12,19 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 
 import importlib
 import os
 import re
 import warnings
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import requests
 import torch
-from huggingface_hub import ModelCard, model_info
-from huggingface_hub.utils import validate_hf_hub_args
+from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download
+from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
 from packaging import version
 from requests.exceptions import HTTPError
 
 from .. import __version__
 from ..utils import (
@@ -38,14 +38,16 @@ from ..utils import (
     is_accelerate_available,
     is_peft_available,
     is_transformers_available,
+    is_transformers_version,
     logging,
 )
 from ..utils.torch_utils import is_compiled_module
+from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transformers_model_from_dduf
 
 
 if is_transformers_available():
     import transformers
-    from transformers import PreTrainedModel
+    from transformers import PreTrainedModel, PreTrainedTokenizerBase
     from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
     from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
     from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
@@ -627,6 +629,7 @@ def load_sub_model(
     low_cpu_mem_usage: bool,
     cached_folder: Union[str, os.PathLike],
     use_safetensors: bool,
+    dduf_entries: Optional[Dict[str, DDUFEntry]],
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""
 
@@ -663,7 +666,7 @@ def load_sub_model(
             f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
         )
 
-    load_method = getattr(class_obj, load_method_name)
+    load_method = _get_load_method(class_obj, load_method_name, is_dduf=dduf_entries is not None)
 
     # add kwargs to loading method
     diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -721,7 +724,10 @@ def load_sub_model(
         loading_kwargs["low_cpu_mem_usage"] = False
 
     # check if the module is in a subdirectory
-    if os.path.isdir(os.path.join(cached_folder, name)):
+    if dduf_entries:
+        loading_kwargs["dduf_entries"] = dduf_entries
+        loaded_sub_model = load_method(name, **loading_kwargs)
+    elif os.path.isdir(os.path.join(cached_folder, name)):
         loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
     else:
         # else load from the root directory
@@ -746,6 +752,22 @@ def load_sub_model(
     return loaded_sub_model
 
 
+def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> Callable:
+    """
+    Return the method to load the sub model.
+
+    In practice, this method will return the `"from_pretrained"` (or `load_method_name`) method of the class object
+    except if loading from a DDUF checkpoint. In that case, transformers models and tokenizers have a specific loading
+    method that we need to use.
+    """
+    if is_dduf:
+        if issubclass(class_obj, PreTrainedTokenizerBase):
+            return lambda *args, **kwargs: _load_tokenizer_from_dduf(class_obj, *args, **kwargs)
+        if issubclass(class_obj, PreTrainedModel):
+            return lambda *args, **kwargs: _load_transformers_model_from_dduf(class_obj, *args, **kwargs)
+    return getattr(class_obj, load_method_name)
+
+
 def _fetch_class_library_tuple(module):
     # import it here to avoid circular import
     diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -968,3 +990,70 @@ def _get_ignore_patterns(
         )
 
     return ignore_patterns
+
+
+def _download_dduf_file(
+    pretrained_model_name: str,
+    dduf_file: str,
+    pipeline_class_name: str,
+    cache_dir: str,
+    proxies: str,
+    local_files_only: bool,
+    token: str,
+    revision: str,
+):
+    model_info_call_error = None
+    if not local_files_only:
+        try:
+            info = model_info(pretrained_model_name, token=token, revision=revision)
+        except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
+            logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
+            local_files_only = True
+            model_info_call_error = e  # save error to reraise it if model is not cached locally
+
+    if (
+        not local_files_only
+        and dduf_file is not None
+        and dduf_file not in (sibling.rfilename for sibling in info.siblings)
+    ):
+        raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.")
+
+    try:
+        user_agent = {"pipeline_class": pipeline_class_name, "dduf": True}
+        cached_folder = snapshot_download(
+            pretrained_model_name,
+            cache_dir=cache_dir,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            allow_patterns=[dduf_file],
+            user_agent=user_agent,
+        )
+        return cached_folder
+    except FileNotFoundError:
+        # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache.
+        # This can happen in two cases:
+        # 1. If the user passed `local_files_only=True` => we raise the error directly
+        # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error
+        if model_info_call_error is None:
+            # 1. user passed `local_files_only=True`
+            raise
+        else:
+            # 2. we forced `local_files_only=True` when `model_info` failed
+            raise EnvironmentError(
+                f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
+                " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
+                " above."
+            ) from model_info_call_error
+
+
+def _maybe_raise_error_for_incorrect_transformers(config_dict):
+    has_transformers_component = False
+    for k in config_dict:
+        if isinstance(config_dict[k], list):
+            has_transformers_component = config_dict[k][0] == "transformers"
+            if has_transformers_component:
+                break
+    if has_transformers_component and not is_transformers_version(">", "4.47.1"):
+        raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")
```
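`_download_dduf_file` above relies on `snapshot_download` with an `allow_patterns` filter, so only the single `.dduf` archive is fetched instead of the whole repository. Stripped of the error handling, it is roughly equivalent to:

```py
from huggingface_hub import snapshot_download

# Downloads (or reuses from cache) just FLUX.1-dev.dduf from the repo
cached_folder = snapshot_download(
    "DDUF/FLUX.1-dev-DDUF",
    allow_patterns=["FLUX.1-dev.dduf"],
)
```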
src/diffusers/pipelines/pipeline_utils.py

```diff
@@ -29,10 +29,12 @@ import PIL.Image
 import requests
 import torch
 from huggingface_hub import (
+    DDUFEntry,
     ModelCard,
     create_repo,
     hf_hub_download,
     model_info,
+    read_dduf_file,
     snapshot_download,
 )
 from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
@@ -72,6 +74,7 @@ from .pipeline_loading_utils import (
     CONNECTED_PIPES_KEYS,
     CUSTOM_PIPELINE_FILE_NAME,
     LOADABLE_CLASSES,
+    _download_dduf_file,
     _fetch_class_library_tuple,
     _get_custom_components_and_folders,
     _get_custom_pipeline_class,
@@ -79,6 +82,7 @@ from .pipeline_loading_utils import (
     _get_ignore_patterns,
     _get_pipeline_class,
     _identify_model_variants,
+    _maybe_raise_error_for_incorrect_transformers,
     _maybe_raise_warning_for_inpainting,
     _resolve_custom_pipeline_and_cls,
     _unwrap_model,
@@ -218,6 +222,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                 namespace).
+
             kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
@@ -531,6 +536,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                       saved using
                     [`~DiffusionPipeline.save_pretrained`].
+                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                 dtype is automatically derived from the model's weights.
@@ -625,6 +631,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             variant (`str`, *optional*):
                 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                 loading `from_flax`.
+            dduf_file(`str`, *optional*):
+                Load weights from the specified dduf file.
 
         <Tip>
 
@@ -674,6 +682,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         offload_state_dict = kwargs.pop("offload_state_dict", False)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
+        dduf_file = kwargs.pop("dduf_file", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
@@ -722,6 +731,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
             )
 
+        if dduf_file:
+            if custom_pipeline:
+                raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
+            if load_connected_pipeline:
+                raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.")
+
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
         if not os.path.isdir(pretrained_model_name_or_path):
@@ -744,6 +759,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 custom_pipeline=custom_pipeline,
                 custom_revision=custom_revision,
                 variant=variant,
+                dduf_file=dduf_file,
                 load_connected_pipeline=load_connected_pipeline,
                 **kwargs,
             )
@@ -765,7 +781,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 )
                 logger.warning(warn_msg)
 
-        config_dict = cls.load_config(cached_folder)
+        dduf_entries = None
+        if dduf_file:
+            dduf_file_path = os.path.join(cached_folder, dduf_file)
+            dduf_entries = read_dduf_file(dduf_file_path)
+            # The reader contains already all the files needed, no need to check it again
+            cached_folder = ""
+
+        config_dict = cls.load_config(cached_folder, dduf_entries=dduf_entries)
+
+        if dduf_file:
+            _maybe_raise_error_for_incorrect_transformers(config_dict)
 
         # pop out "_ignore_files" as it is only needed for download
         config_dict.pop("_ignore_files", None)
@@ -943,6 +969,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     low_cpu_mem_usage=low_cpu_mem_usage,
                     cached_folder=cached_folder,
                     use_safetensors=use_safetensors,
+                    dduf_entries=dduf_entries,
                 )
             logger.info(
                 f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1256,6 +1283,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             variant (`str`, *optional*):
                 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                 loading `from_flax`.
+            dduf_file(`str`, *optional*):
+                Load weights from the specified DDUF file.
             use_safetensors (`bool`, *optional*, defaults to `None`):
                 If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                 safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
@@ -1296,6 +1325,23 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
         trust_remote_code = kwargs.pop("trust_remote_code", False)
+        dduf_file: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_file", None)
+
+        if dduf_file:
+            if custom_pipeline:
+                raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.")
+            if load_connected_pipeline:
+                raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.")
+            return _download_dduf_file(
+                pretrained_model_name=pretrained_model_name,
+                dduf_file=dduf_file,
+                pipeline_class_name=cls.__name__,
+                cache_dir=cache_dir,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+            )
 
         allow_pickle = False
         if use_safetensors is None:
@@ -1375,7 +1421,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
             # also allow downloading config.json files with the model
             allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
-
             allow_patterns += [
                 SCHEDULER_CONFIG_NAME,
                 CONFIG_NAME,
@@ -1471,7 +1516,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 user_agent=user_agent,
             )
 
-
        # retrieve pipeline class from local file
         cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
         cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
```
src/diffusers/pipelines/transformers_loading_utils.py (new file, 121 lines)

```python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import tempfile
from typing import TYPE_CHECKING, Dict

from huggingface_hub import DDUFEntry
from tqdm import tqdm

from ..utils import is_safetensors_available, is_transformers_available, is_transformers_version


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer

if is_transformers_available():
    from transformers import PreTrainedModel, PreTrainedTokenizer

if is_safetensors_available():
    import safetensors.torch


def _load_tokenizer_from_dduf(
    cls: "PreTrainedTokenizer", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
) -> "PreTrainedTokenizer":
    """
    Load a tokenizer from a DDUF archive.

    In practice, `transformers` do not provide a way to load a tokenizer from a DDUF archive. This function is a
    workaround by extracting the tokenizer files from the DDUF archive and loading the tokenizer from the extracted
    files. There is an extra cost of extracting the files, but of limited impact as the tokenizer files are usually
    small-ish.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        for entry_name, entry in dduf_entries.items():
            if entry_name.startswith(name + "/"):
                tmp_entry_path = os.path.join(tmp_dir, *entry_name.split("/"))
                # need to create intermediary directory if they don't exist
                os.makedirs(os.path.dirname(tmp_entry_path), exist_ok=True)
                with open(tmp_entry_path, "wb") as f:
                    with entry.as_mmap() as mm:
                        f.write(mm)
        return cls.from_pretrained(os.path.dirname(tmp_entry_path), **kwargs)


def _load_transformers_model_from_dduf(
    cls: "PreTrainedModel", name: str, dduf_entries: Dict[str, DDUFEntry], **kwargs
) -> "PreTrainedModel":
    """
    Load a transformers model from a DDUF archive.

    In practice, `transformers` do not provide a way to load a model from a DDUF archive. This function is a workaround
    by instantiating a model from the config file and loading the weights from the DDUF archive directly.
    """
    config_file = dduf_entries.get(f"{name}/config.json")
    if config_file is None:
        raise EnvironmentError(
            f"Could not find a config.json file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    generation_config = dduf_entries.get(f"{name}/generation_config.json", None)

    weight_files = [
        entry
        for entry_name, entry in dduf_entries.items()
        if entry_name.startswith(f"{name}/") and entry_name.endswith(".safetensors")
    ]
    if not weight_files:
        raise EnvironmentError(
            f"Could not find any weight file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    if not is_safetensors_available():
        raise EnvironmentError(
            "Safetensors is not available, cannot load model from DDUF. Please `pip install safetensors`."
        )
    if is_transformers_version("<", "4.47.0"):
        raise ImportError(
            "You need to install `transformers>4.47.0` in order to load a transformers model from a DDUF file. "
            "You can install it with: `pip install --upgrade transformers`"
        )

    with tempfile.TemporaryDirectory() as tmp_dir:
        from transformers import AutoConfig, GenerationConfig

        tmp_config_file = os.path.join(tmp_dir, "config.json")
        with open(tmp_config_file, "w") as f:
            f.write(config_file.read_text())
        config = AutoConfig.from_pretrained(tmp_config_file)
        if generation_config is not None:
            tmp_generation_config_file = os.path.join(tmp_dir, "generation_config.json")
            with open(tmp_generation_config_file, "w") as f:
                f.write(generation_config.read_text())
            generation_config = GenerationConfig.from_pretrained(tmp_generation_config_file)
        state_dict = {}
        with contextlib.ExitStack() as stack:
            for entry in tqdm(weight_files, desc="Loading state_dict"):  # Loop over safetensors files
                # Memory-map the safetensors file
                mmap = stack.enter_context(entry.as_mmap())
                # Load tensors from the memory-mapped file
                tensors = safetensors.torch.load(mmap)
                # Update the state dictionary with tensors
                state_dict.update(tensors)
        return cls.from_pretrained(
            pretrained_model_name_or_path=None,
            config=config,
            generation_config=generation_config,
            state_dict=state_dict,
            **kwargs,
        )
```
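The tokenizer helper above works by extracting every entry under the component's prefix into a temporary directory and calling the regular `from_pretrained` on it. A self-contained sketch of that idea, assuming a local archive that contains a `tokenizer/` folder:

```py
import os
import tempfile

from huggingface_hub import read_dduf_file
from transformers import AutoTokenizer

entries = read_dduf_file("FLUX.1-dev.dduf")
with tempfile.TemporaryDirectory() as tmp_dir:
    for entry_name, entry in entries.items():
        if entry_name.startswith("tokenizer/"):
            path = os.path.join(tmp_dir, *entry_name.split("/"))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            # Copy the raw bytes out of the archive
            with open(path, "wb") as f, entry.as_mmap() as mm:
                f.write(mm)
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(tmp_dir, "tokenizer"))
```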
src/diffusers/utils/__init__.py

```diff
@@ -70,6 +70,7 @@ from .import_utils import (
     is_gguf_available,
     is_gguf_version,
     is_google_colab,
+    is_hf_hub_version,
     is_inflect_available,
     is_invisible_watermark_available,
     is_k_diffusion_available,
```
src/diffusers/utils/hub_utils.py

```diff
@@ -26,6 +26,7 @@ from typing import Dict, List, Optional, Union
 from uuid import uuid4
 
 from huggingface_hub import (
+    DDUFEntry,
     ModelCard,
     ModelCardData,
     create_repo,
@@ -291,9 +292,26 @@ def _get_model_file(
     user_agent: Optional[Union[Dict, str]] = None,
     revision: Optional[str] = None,
     commit_hash: Optional[str] = None,
+    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
 ):
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-    if os.path.isfile(pretrained_model_name_or_path):
+
+    if dduf_entries:
+        if subfolder is not None:
+            raise ValueError(
+                "DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
+                "Please check the DDUF structure"
+            )
+        model_file = (
+            weights_name
+            if pretrained_model_name_or_path == ""
+            else "/".join([pretrained_model_name_or_path, weights_name])
+        )
+        if model_file in dduf_entries:
+            return model_file
+        else:
+            raise EnvironmentError(f"Error no file named {weights_name} found in archive {dduf_entries.keys()}.")
+    elif os.path.isfile(pretrained_model_name_or_path):
         return pretrained_model_name_or_path
     elif os.path.isdir(pretrained_model_name_or_path):
         if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
@@ -419,6 +437,7 @@ def _get_checkpoint_shard_files(
     user_agent=None,
     revision=None,
     subfolder="",
+    dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
 ):
     """
     For a given model:
@@ -430,11 +449,18 @@ def _get_checkpoint_shard_files(
     For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the
     index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).
     """
-    if not os.path.isfile(index_filename):
-        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
+    if dduf_entries:
+        if index_filename not in dduf_entries:
+            raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
+    else:
+        if not os.path.isfile(index_filename):
+            raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")
 
-    with open(index_filename, "r") as f:
-        index = json.loads(f.read())
+    if dduf_entries:
+        index = json.loads(dduf_entries[index_filename].read_text())
+    else:
+        with open(index_filename, "r") as f:
+            index = json.loads(f.read())
 
     original_shard_filenames = sorted(set(index["weight_map"].values()))
     sharded_metadata = index["metadata"]
@@ -448,6 +474,8 @@ def _get_checkpoint_shard_files(
             pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames
         )
         return shards_path, sharded_metadata
+    elif dduf_entries:
+        return shards_path, sharded_metadata
 
     # At this stage pretrained_model_name_or_path is a model identifier on the Hub
     allow_patterns = original_shard_filenames
```
src/diffusers/utils/import_utils.py

```diff
@@ -115,6 +115,13 @@ try:
 except importlib_metadata.PackageNotFoundError:
     _transformers_available = False
 
+_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None
+try:
+    _hf_hub_version = importlib_metadata.version("huggingface_hub")
+    logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}")
+except importlib_metadata.PackageNotFoundError:
+    _hf_hub_available = False
+
 
 _inflect_available = importlib.util.find_spec("inflect") is not None
 try:
@@ -767,6 +774,21 @@ def is_transformers_version(operation: str, version: str):
     return compare_versions(parse(_transformers_version), operation, version)
 
 
+def is_hf_hub_version(operation: str, version: str):
+    """
+    Compares the current Hugging Face Hub version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _hf_hub_available:
+        return False
+    return compare_versions(parse(_hf_hub_version), operation, version)
+
+
 def is_accelerate_version(operation: str, version: str):
     """
     Compares the current Accelerate version to a given reference with an operation.
```
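`is_hf_hub_version` mirrors the existing `is_transformers_version` helper and returns `False` when `huggingface_hub` is not installed. A sketch of typical gating, using the same threshold the DDUF tests below use:

```py
from diffusers.utils import is_hf_hub_version

# DDUF reading requires a sufficiently recent huggingface_hub
if is_hf_hub_version(">", "0.26.5"):
    ...
```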
src/diffusers/utils/testing_utils.py

```diff
@@ -478,6 +478,18 @@ def require_bitsandbytes_version_greater(bnb_version):
     return decorator
 
 
+def require_hf_hub_version_greater(hf_hub_version):
+    def decorator(test_case):
+        correct_hf_hub_version = version.parse(
+            version.parse(importlib.metadata.version("huggingface_hub")).base_version
+        ) > version.parse(hf_hub_version)
+        return unittest.skipUnless(
+            correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
+        )(test_case)
+
+    return decorator
+
+
 def require_gguf_version_greater_or_equal(gguf_version):
     def decorator(test_case):
         correct_gguf_version = is_gguf_available() and version.parse(
```
tests/pipelines/allegro/test_allegro.py

```diff
@@ -14,6 +14,8 @@
 
 import gc
 import inspect
+import os
+import tempfile
 import unittest
 
 import numpy as np
@@ -24,7 +26,9 @@ from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLA
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     numpy_cosine_similarity_distance,
+    require_hf_hub_version_greater,
     require_torch_gpu,
+    require_transformers_version_greater,
     slow,
     torch_device,
 )
@@ -297,6 +301,35 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             "VAE tiling should not affect the inference results",
         )
 
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_save_load_dduf(self):
+        # reimplement because it needs `enable_tiling()` on the loaded pipe.
+        from huggingface_hub import export_folder_as_dduf
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device="cpu")
+        inputs.pop("generator")
+        inputs["generator"] = torch.manual_seed(0)
+
+        pipeline_out = pipe(**inputs)[0].cpu()
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf")
+            pipe.save_pretrained(tmpdir, safe_serialization=True)
+            export_folder_as_dduf(dduf_filename, folder_path=tmpdir)
+            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device)
+
+        loaded_pipe.vae.enable_tiling()
+        inputs["generator"] = torch.manual_seed(0)
+        loaded_pipeline_out = loaded_pipe(**inputs)[0].cpu()
+
+        assert np.allclose(pipeline_out, loaded_pipeline_out)
+
 
 @slow
 @require_torch_gpu
```
tests/pipelines/audioldm/test_audioldm.py

```diff
@@ -63,6 +63,8 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         ]
     )
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
```

tests/pipelines/audioldm2/test_audioldm2.py

```diff
@@ -70,6 +70,8 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         ]
     )
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = AudioLDM2UNet2DConditionModel(
```

tests/pipelines/blipdiffusion/test_blipdiffusion.py

```diff
@@ -60,6 +60,8 @@ class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         "prompt_reps",
     ]
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
```

tests/pipelines/controlnet/test_controlnet.py

```diff
@@ -291,6 +291,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
@@ -523,6 +525,8 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
```
tests/pipelines/controlnet/test_controlnet_blip_diffusion.py

```diff
@@ -68,6 +68,8 @@ class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.Tes
         "prompt_reps",
     ]
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
```

tests/pipelines/controlnet/test_controlnet_img2img.py

```diff
@@ -198,6 +198,8 @@ class StableDiffusionMultiControlNetPipelineFastTests(
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
     image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
```

tests/pipelines/controlnet/test_controlnet_inpaint.py

```diff
@@ -257,6 +257,8 @@ class MultiControlNetInpaintPipelineFastTests(
     params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
```

tests/pipelines/controlnet/test_controlnet_sdxl.py

```diff
@@ -78,6 +78,8 @@ class ControlNetPipelineSDXLFastTests(
         }
     )
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
@@ -487,6 +487,8 @@ class StableDiffusionXLMultiControlNetPipelineFastTests(
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
@@ -692,6 +694,8 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     image_params = frozenset([])  # TO_DO: add image_params once refactored VaeImageProcessor.preprocess
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
```
tests/pipelines/deepfloyd_if/test_if.py

```diff
@@ -26,7 +26,9 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     load_numpy,
     require_accelerator,
+    require_hf_hub_version_greater,
     require_torch_gpu,
+    require_transformers_version_greater,
     skip_mps,
     slow,
     torch_device,
@@ -89,6 +91,11 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.T
     def test_xformers_attention_forwardGenerator_pass(self):
         self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3)
 
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_save_load_dduf(self):
+        super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
 
 @slow
 @require_torch_gpu
```

tests/pipelines/deepfloyd_if/test_if_img2img.py

```diff
@@ -26,7 +26,9 @@ from diffusers.utils.testing_utils import (
     floats_tensor,
     load_numpy,
     require_accelerator,
+    require_hf_hub_version_greater,
     require_torch_gpu,
+    require_transformers_version_greater,
     skip_mps,
     slow,
     torch_device,
@@ -100,6 +102,11 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni
             expected_max_diff=1e-2,
         )
 
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_save_load_dduf(self):
+        super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
 
 @slow
 @require_torch_gpu
```

tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py

```diff
@@ -26,7 +26,9 @@ from diffusers.utils.testing_utils import (
     floats_tensor,
     load_numpy,
     require_accelerator,
+    require_hf_hub_version_greater,
     require_torch_gpu,
+    require_transformers_version_greater,
     skip_mps,
     slow,
     torch_device,
@@ -97,6 +99,11 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT
             expected_max_diff=1e-2,
         )
 
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_save_load_dduf(self):
+        super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
 
 @slow
 @require_torch_gpu
```

tests/pipelines/deepfloyd_if/test_if_inpainting.py

```diff
@@ -26,7 +26,9 @@ from diffusers.utils.testing_utils import (
     floats_tensor,
     load_numpy,
     require_accelerator,
+    require_hf_hub_version_greater,
     require_torch_gpu,
+    require_transformers_version_greater,
     skip_mps,
     slow,
     torch_device,
@@ -97,6 +99,11 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin,
             expected_max_diff=1e-2,
         )
 
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_save_load_dduf(self):
+        super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
 
 @slow
 @require_torch_gpu
```

tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py

```diff
@@ -26,7 +26,9 @@ from diffusers.utils.testing_utils import (
     floats_tensor,
     load_numpy,
     require_accelerator,
+    require_hf_hub_version_greater,
     require_torch_gpu,
+    require_transformers_version_greater,
     skip_mps,
     slow,
     torch_device,
@@ -99,6 +101,11 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli
             expected_max_diff=1e-2,
         )
 
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_save_load_dduf(self):
+        super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
 
 @slow
 @require_torch_gpu
```

tests/pipelines/deepfloyd_if/test_if_superresolution.py

```diff
@@ -26,7 +26,9 @@ from diffusers.utils.testing_utils import (
     floats_tensor,
     load_numpy,
     require_accelerator,
+    require_hf_hub_version_greater,
     require_torch_gpu,
+    require_transformers_version_greater,
     skip_mps,
     slow,
     torch_device,
@@ -92,6 +94,11 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi
             expected_max_diff=1e-2,
         )
 
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_save_load_dduf(self):
+        super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
+
 
 @slow
 @require_torch_gpu
```
@@ -59,6 +59,8 @@ class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unit
|
||||
# No `output_type`.
|
||||
required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"])
|
||||
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
scheduler = DDIMScheduler(
|
||||
|
||||
@@ -204,6 +204,8 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     ]
     test_xformers_attention = False
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummy = Dummies()
         return dummy.get_dummy_components()

@@ -52,6 +52,8 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
     ]
     test_xformers_attention = True
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummy = Dummies()
         prior_dummy = PriorDummies()

@@ -160,6 +162,8 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
     ]
     test_xformers_attention = False
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummy = Img2ImgDummies()
         prior_dummy = PriorDummies()

@@ -269,6 +273,8 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
     ]
     test_xformers_attention = False
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummy = InpaintDummies()
         prior_dummy = PriorDummies()

@@ -226,6 +226,8 @@ class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     ]
     test_xformers_attention = False
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummies = Dummies()
         return dummies.get_dummy_components()

@@ -220,6 +220,8 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     ]
     test_xformers_attention = False
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummies = Dummies()
         return dummies.get_dummy_components()

@@ -184,6 +184,8 @@ class KandinskyPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     ]
     test_xformers_attention = False
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummy = Dummies()
         return dummy.get_dummy_components()

@@ -57,6 +57,8 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
     test_xformers_attention = True
     callback_cfg_params = ["image_embds"]
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummy = Dummies()
         prior_dummy = PriorDummies()

@@ -181,6 +183,8 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest
     test_xformers_attention = False
     callback_cfg_params = ["image_embds"]
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummy = Img2ImgDummies()
         prior_dummy = PriorDummies()

@@ -302,6 +306,8 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest
     ]
     test_xformers_attention = False
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummy = InpaintDummies()
         prior_dummy = PriorDummies()

@@ -186,6 +186,8 @@ class KandinskyV22PriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
     callback_cfg_params = ["prompt_embeds", "text_encoder_hidden_states", "text_mask"]
     test_xformers_attention = False
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         dummies = Dummies()
         return dummies.get_dummy_components()

@@ -59,6 +59,8 @@ class KandinskyV22PriorEmb2EmbPipelineFastTests(PipelineTesterMixin, unittest.Te
     ]
     test_xformers_attention = False
 
+    supports_dduf = False
+
     @property
     def text_embedder_hidden_size(self):
         return 32
@@ -47,6 +47,8 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
 
+    supports_dduf = False
+
     def get_dummy_components(self, time_cond_proj_dim=None):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(

@@ -51,6 +51,8 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
 
+    supports_dduf = False
+
     # Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components
     def get_dummy_components(self, time_cond_proj_dim=None):
         torch.manual_seed(0)

@@ -31,6 +31,8 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM
     )
     batch_params = frozenset(["prompt", "negative_prompt"])
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         transformer = LuminaNextDiT2DModel(

@@ -65,6 +65,8 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         ]
     )
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(

@@ -56,6 +56,8 @@ class KolorsPAGPipelineFastTests(
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"})
 
+    supports_dduf = False
+
     # Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components
     def get_dummy_components(self, time_cond_proj_dim=None):
         torch.manual_seed(0)

@@ -53,6 +53,8 @@ class SanaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     )
     test_xformers_attention = False
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         transformer = SanaTransformer2DModel(
@@ -82,6 +82,8 @@ class StableDiffusionXLPAGImg2ImgPipelineFastTests(
         {"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
     )
 
+    supports_dduf = False
+
     # based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_img2img_pipeline.get_dummy_components
     def get_dummy_components(
         self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False

@@ -82,6 +82,8 @@ class StableDiffusionXLPAGInpaintPipelineFastTests(
         {"add_text_embeds", "add_time_ids", "mask", "masked_image_latents"}
     )
 
+    supports_dduf = False
+
     # based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipelineFastTests.get_dummy_components
     def get_dummy_components(
         self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False

@@ -46,6 +46,8 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
     image_params = frozenset([])  # TO_DO: update the image_prams once refactored VaeImageProcessor.preprocess
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(

@@ -50,6 +50,8 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     ]
     test_xformers_attention = False
 
+    supports_dduf = False
+
     @property
     def text_embedder_hidden_size(self):
         return 16

@@ -70,6 +70,7 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     )
     # There is not xformers version of the StableAudioPipeline custom attention processor
     test_xformers_attention = False
+    supports_dduf = False
 
     def get_dummy_components(self):
         torch.manual_seed(0)

@@ -76,6 +76,8 @@ class StableDiffusionDepth2ImgPipelineFastTests(
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"depth_mask"})
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
@@ -389,6 +389,8 @@ class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterM
 
 
 class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
+    supports_dduf = False
+
     def get_dummy_components(self, time_cond_proj_dim=None):
         return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)
 

@@ -66,6 +66,8 @@ class GligenTextImagePipelineFastTests(
     image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(

@@ -58,6 +58,8 @@ class StableDiffusionImageVariationPipelineFastTests(
     # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
     image_latents_params = frozenset([])
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(

@@ -422,6 +422,8 @@ class StableDiffusionXLAdapterPipelineFastTests(
 class StableDiffusionXLMultiAdapterPipelineFastTests(
     StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
 ):
+    supports_dduf = False
+
     def get_dummy_components(self, time_cond_proj_dim=None):
         return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)
 

@@ -77,6 +77,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests(
         {"add_text_embeds", "add_time_ids", "add_neg_time_ids"}
     )
 
+    supports_dduf = False
+
     def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(

@@ -72,6 +72,8 @@ class StableDiffusionXLInpaintPipelineFastTests(
         }
     )
 
+    supports_dduf = False
+
     def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(

@@ -51,6 +51,8 @@ class StableUnCLIPImg2ImgPipelineFastTests(
     )  # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
     image_latents_params = frozenset([])
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         embedder_hidden_size = 32
         embedder_projection_dim = embedder_hidden_size

@@ -58,6 +58,8 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa
         ]
     )
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         torch.manual_seed(0)
         unet = UNetSpatioTemporalConditionModel(
@@ -75,9 +75,11 @@ from diffusers.utils.testing_utils import (
     nightly,
     require_compel,
     require_flax,
+    require_hf_hub_version_greater,
     require_onnxruntime,
     require_torch_2,
     require_torch_gpu,
+    require_transformers_version_greater,
     run_test_in_subprocess,
     slow,
     torch_device,

@@ -981,6 +983,18 @@ class DownloadTests(unittest.TestCase):
         assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files)
         assert len(files) == 14
 
+    def test_download_dduf_with_custom_pipeline_raises_error(self):
+        with self.assertRaises(NotImplementedError):
+            _ = DiffusionPipeline.download(
+                "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline"
+            )
+
+    def test_download_dduf_with_connected_pipeline_raises_error(self):
+        with self.assertRaises(NotImplementedError):
+            _ = DiffusionPipeline.download(
+                "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True
+            )
+
     def test_get_pipeline_class_from_flax(self):
         flax_config = {"_class_name": "FlaxStableDiffusionPipeline"}
         config = {"_class_name": "StableDiffusionPipeline"}
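Both download guards fail fast instead of materializing a partially valid download. A sketch of the behavior they pin down, using the tiny fixture repo from the tests above:

```python
from diffusers import DiffusionPipeline

# Downloading a DDUF archive on its own works...
DiffusionPipeline.download("DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf")

# ...but combining dduf_file with a custom or connected pipeline is rejected.
try:
    DiffusionPipeline.download(
        "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline"
    )
except NotImplementedError:
    pass  # custom pipelines are not supported inside DDUF archives yet
```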
@@ -1802,6 +1816,55 @@ class PipelineFastTests(unittest.TestCase):
         sd.maybe_free_model_hooks()
         assert sd._offload_gpu_id == 5
 
+    @parameterized.expand([torch.float32, torch.float16])
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_load_dduf_from_hub(self, dtype):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe = DiffusionPipeline.from_pretrained(
+                "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, torch_dtype=dtype
+            ).to(torch_device)
+            out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
+
+            pipe.save_pretrained(tmpdir)
+            loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=dtype).to(torch_device)
+
+            out_2 = loaded_pipe(
+                prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np"
+            ).images
+
+            self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4))
+
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_load_dduf_from_hub_local_files_only(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe = DiffusionPipeline.from_pretrained(
+                "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir
+            ).to(torch_device)
+            out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
+
+            local_files_pipe = DiffusionPipeline.from_pretrained(
+                "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, local_files_only=True
+            ).to(torch_device)
+            out_2 = local_files_pipe(
+                prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np"
+            ).images
+
+            self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4))
+
+    def test_dduf_raises_error_with_custom_pipeline(self):
+        with self.assertRaises(NotImplementedError):
+            _ = DiffusionPipeline.from_pretrained(
+                "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline"
+            )
+
+    def test_dduf_raises_error_with_connected_pipeline(self):
+        with self.assertRaises(NotImplementedError):
+            _ = DiffusionPipeline.from_pretrained(
+                "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True
+            )
+
     def test_wrong_model(self):
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
         with self.assertRaises(ValueError) as error_context:

@@ -1812,6 +1875,27 @@ class PipelineFastTests(unittest.TestCase):
         assert "is of type" in str(error_context.exception)
         assert "but should be" in str(error_context.exception)
 
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_dduf_load_sharded_checkpoint_diffusion_model(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe = DiffusionPipeline.from_pretrained(
+                "hf-internal-testing/tiny-flux-dev-pipe-sharded-checkpoint-DDUF",
+                dduf_file="tiny-flux-dev-pipe-sharded-checkpoint.dduf",
+                cache_dir=tmpdir,
+            ).to(torch_device)
+
+            out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images
+
+            pipe.save_pretrained(tmpdir)
+            loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir).to(torch_device)
+
+            out_2 = loaded_pipe(
+                prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np"
+            ).images
+
+            self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4))
+
 
 @slow
 @require_torch_gpu
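Taken together, these tests fix the user-facing flow: load a `.dduf` archive from the Hub with `from_pretrained`, run it, and re-save it as a regular multi-folder pipeline; the sharded-checkpoint test confirms the same path works when weights inside the archive are split across files. A minimal sketch of that flow (output path illustrative):

```python
import torch

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", torch_dtype=torch.float16
).to("cuda")
image = pipe(prompt="dog", num_inference_steps=5).images[0]

# save_pretrained writes the standard multi-folder layout;
# it does not produce a new DDUF archive.
pipe.save_pretrained("./unpacked-pipeline")
```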
@@ -43,7 +43,9 @@ from diffusers.utils.testing_utils import (
     CaptureLogger,
     require_accelerate_version_greater,
     require_accelerator,
+    require_hf_hub_version_greater,
     require_torch,
+    require_transformers_version_greater,
     skip_mps,
     torch_device,
 )

@@ -986,6 +988,8 @@ class PipelineTesterMixin:
 
     test_xformers_attention = True
 
+    supports_dduf = True
+
     def get_generator(self, seed):
         device = torch_device if torch_device != "mps" else "cpu"
         generator = torch.Generator(device).manual_seed(seed)

@@ -1990,6 +1994,39 @@ class PipelineTesterMixin:
             )
         )
 
+    @require_hf_hub_version_greater("0.26.5")
+    @require_transformers_version_greater("4.47.1")
+    def test_save_load_dduf(self, atol=1e-4, rtol=1e-4):
+        if not self.supports_dduf:
+            return
+
+        from huggingface_hub import export_folder_as_dduf
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device="cpu")
+        inputs.pop("generator")
+        inputs["generator"] = torch.manual_seed(0)
+
+        pipeline_out = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf")
+            pipe.save_pretrained(tmpdir, safe_serialization=True)
+            export_folder_as_dduf(dduf_filename, folder_path=tmpdir)
+            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device)
+
+            inputs["generator"] = torch.manual_seed(0)
+            loaded_pipeline_out = loaded_pipe(**inputs)[0]
+
+        if isinstance(pipeline_out, np.ndarray) and isinstance(loaded_pipeline_out, np.ndarray):
+            assert np.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol)
+        elif isinstance(pipeline_out, torch.Tensor) and isinstance(loaded_pipeline_out, torch.Tensor):
+            assert torch.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol)
+
 
 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):
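The round-trip test builds its archive with `huggingface_hub`'s export helper rather than a diffusers-side writer; outside the test suite the same two steps turn a saved pipeline into a shareable archive (paths illustrative):

```python
from huggingface_hub import export_folder_as_dduf

# Requires huggingface_hub >= 0.26.5, the version the tests above are gated on.
pipe.save_pretrained("./my-pipeline", safe_serialization=True)
export_folder_as_dduf("my-pipeline.dduf", folder_path="./my-pipeline")
```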
@@ -66,6 +66,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa
         "super_res_num_inference_steps",
     ]
     test_xformers_attention = False
+    supports_dduf = False
 
     @property
     def text_embedder_hidden_size(self):

@@ -86,6 +86,8 @@ class UniDiffuserPipelineFastTests(
     # vae_latents, not latents, is the argument that corresponds to VAE latent inputs
     image_latents_params = frozenset(["vae_latents"])
 
+    supports_dduf = False
+
     def get_dummy_components(self):
         unet = UniDiffuserModel.from_pretrained(
             "hf-internal-testing/unidiffuser-diffusers-test",