Mirror of https://github.com/huggingface/diffusers.git
add load textual inversion embeddings to stable diffusion (#2009)
* add load textual inversion embeddings draft
* fix quality
* fix typo
* make fix copies
* move to textual inversion mixin
* make it accept from sd-concept library
* accept list of paths to embeddings
* fix styling of stable diffusion pipeline
* add dummy TextualInversionMixin
* add docstring to textualinversionmixin
* add case for parsing embedding from auto1111 UI format

Co-authored-by: Evan Jones <evan.a.jones3@gmail.com>
Co-authored-by: Ana Tamais <aninhamoraestamais@gmail.com>

* fix style after rebase
* move textual inversion mixin to loaders
* move mixin inheritance to DiffusionPipeline from StableDiffusionPipeline
* update dummy class name
* addressed all comments
* fix old dangling import
* fix style
* proposal
* remove bogus
* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>

* finish
* make style
* up
* fix code quality
* fix code quality - again
* fix code quality - 3
* fix alt diffusion code quality
* fix model editing pipeline
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Finish

---------

Co-authored-by: Evan Jones <evan.a.jones3@gmail.com>
Co-authored-by: Ana Tamais <aninhamoraestamais@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
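A minimal usage sketch of the API this PR adds (the concept repo id comes from the docstring below; the exact placeholder token is defined by that repo and is assumed here for illustration):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Pull a textual inversion embedding from the Hub and register its token
# with the pipeline's tokenizer and text encoder.
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")

# The placeholder token name is an assumption; it is defined by the concept repo.
image = pipe("a <low-poly-hd-logos-icons> house").images[0]
```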
@@ -109,6 +109,7 @@ try:
 except OptionalDependencyNotAvailable:
     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
+    from .loaders import TextualInversionLoaderMixin
     from .pipelines import (
         AltDiffusionImg2ImgPipeline,
         AltDiffusionPipeline,
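The new export is gated behind the torch + transformers check. A condensed sketch of the pattern used in `src/diffusers/__init__.py` (the availability helpers are real `diffusers.utils` functions):

```python
from .utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available

try:
    # Raise if either optional backend is missing ...
    if not (is_torch_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # ... and fall back to DummyObject stand-ins that only raise a helpful
    # ImportError when someone actually uses the class.
    from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
    from .loaders import TextualInversionLoaderMixin
```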
@@ -13,18 +13,28 @@
 # limitations under the License.
 import os
 from collections import defaultdict
-from typing import Callable, Dict, Union
+from typing import Callable, Dict, List, Optional, Union

 import torch

 from .models.attention_processor import LoRAAttnProcessor
-from .models.modeling_utils import _get_model_file
-from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
+from .utils import (
+    DIFFUSERS_CACHE,
+    HF_HUB_OFFLINE,
+    _get_model_file,
+    deprecate,
+    is_safetensors_available,
+    is_transformers_available,
+    logging,
+)


 if is_safetensors_available():
     import safetensors

+if is_transformers_available():
+    from transformers import PreTrainedModel, PreTrainedTokenizer
+

 logger = logging.get_logger(__name__)
@@ -32,6 +42,9 @@ logger = logging.get_logger(__name__)
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

+TEXT_INVERSION_NAME = "learned_embeds.bin"
+TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
+

 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -123,13 +136,6 @@ class UNet2DConditionLoadersMixin:
         It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
         models](https://huggingface.co/docs/hub/models-gated#gated-models).

         </Tip>
-
-        <Tip>
-
-        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
-        this method in a firewalled environment.
-
-        </Tip>
         """
@@ -292,5 +298,272 @@ class UNet2DConditionLoadersMixin:

         # Save the model
         save_function(state_dict, os.path.join(save_directory, weight_name))

         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
+
+
+class TextualInversionLoaderMixin:
+    r"""
+    Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder.
+    """
+    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
+        r"""
+        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that
+        corresponds to a multi-vector textual inversion embedding, this function processes the prompt so that the
+        special token is replaced with multiple special tokens, each corresponding to one of the vectors. If the
+        prompt has no textual inversion token, or the textual inversion token is a single vector, the input prompt
+        is returned unchanged.
+
+        Parameters:
+            prompt (`str` or list of `str`):
+                The prompt or prompts to guide the image generation.
+            tokenizer (`PreTrainedTokenizer`):
+                The tokenizer responsible for encoding the prompt into input tokens.
+
+        Returns:
+            `str` or list of `str`: The converted prompt
+        """
+        if not isinstance(prompt, List):
+            prompts = [prompt]
+        else:
+            prompts = prompt
+
+        prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
+
+        if not isinstance(prompt, List):
+            return prompts[0]
+
+        return prompts
+    def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
+        r"""
+        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that
+        corresponds to a multi-vector textual inversion embedding, this function processes the prompt so that the
+        special token is replaced with multiple special tokens, each corresponding to one of the vectors. If the
+        prompt has no textual inversion token, or the textual inversion token is a single vector, the input prompt
+        is returned unchanged.
+
+        Parameters:
+            prompt (`str`):
+                The prompt to guide the image generation.
+            tokenizer (`PreTrainedTokenizer`):
+                The tokenizer responsible for encoding the prompt into input tokens.
+
+        Returns:
+            `str`: The converted prompt
+        """
+        tokens = tokenizer.tokenize(prompt)
+        for token in tokens:
+            if token in tokenizer.added_tokens_encoder:
+                replacement = token
+                i = 1
+                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+                    # note the leading space: "<tok>" expands to "<tok> <tok>_1 ..."
+                    replacement += f" {token}_{i}"
+                    i += 1
+
+                prompt = prompt.replace(token, replacement)
+
+        return prompt
+    def load_textual_inversion(
+        self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
+    ):
+        r"""
+        Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
+        `Automatic1111` formats are supported.
+
+        <Tip warning={true}>
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`):
+                Can be either:
+
+                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
+                      Valid model ids should have an organization name, like
+                      `"sd-concepts-library/low-poly-hd-logos-icons"`.
+                    - A path to a *directory* containing textual inversion weights, e.g.
+                      `./my_text_inversion_directory/`.
+            token (`str`, *optional*):
+                An override token under which to register the embedding; required when the loaded file contains a
+                bare tensor and therefore carries no token name of its own.
+            weight_name (`str`, *optional*):
+                Name of a custom weight file. This should be used in two cases:
+
+                    - The saved textual inversion file is in `diffusers` format, but was saved under a specific
+                      weight name, such as `text_inv.bin`.
+                    - The saved textual inversion file is in the `Automatic1111` format.
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory in which a downloaded pretrained model configuration should be cached if the
+                standard cache should not be used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding
+                the cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to resume downloading incompletely received files instead of deleting them and
+                starting over.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                generated when running `diffusers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
+                a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be
+                any identifier allowed by git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo (either remote on
+                huggingface.co or downloaded locally), you can specify the folder name here.
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or
+                safety of the source; please refer to the mirror site for more information.
+
+        <Tip>
+
+        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+
+        </Tip>
+        """
+        if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
+            raise ValueError(
+                f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
+                f" `{self.load_textual_inversion.__name__}`"
+            )
+
+        if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
+            raise ValueError(
+                f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
+                f" `{self.load_textual_inversion.__name__}`"
+            )
+
+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        if use_safetensors and not is_safetensors_available():
+            raise ValueError(
+                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with"
+                " `pip install safetensors`"
+            )
+
+        # If the caller did not decide, prefer safetensors when available but
+        # allow a silent fallback to pickle (.bin) weights when it is not.
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = is_safetensors_available()
+            allow_pickle = True
+
+        user_agent = {
+            "file_type": "text_inversion",
+            "framework": "pytorch",
+        }
+        # 1. Load textual inversion file
+        model_file = None
+        # Let's first try to load .safetensors weights
+        if (use_safetensors and weight_name is None) or (
+            weight_name is not None and weight_name.endswith(".safetensors")
+        ):
+            try:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = safetensors.torch.load_file(model_file, device="cpu")
+            except Exception as e:
+                if not allow_pickle:
+                    raise e
+
+                model_file = None
+
+        if model_file is None:
+            model_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=weight_name or TEXT_INVERSION_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+            state_dict = torch.load(model_file, map_location="cpu")
+        # 2. Load token and embedding correctly from file
+        if isinstance(state_dict, torch.Tensor):
+            if token is None:
+                raise ValueError(
+                    "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                )
+            loaded_token = token  # keeps the comparison below well-defined
+            embedding = state_dict
+        elif len(state_dict) == 1:
+            # diffusers format: {token: embedding}
+            loaded_token, embedding = next(iter(state_dict.items()))
+        elif "string_to_param" in state_dict:
+            # A1111 format: {"name": token, "string_to_param": {"*": embedding}, ...}
+            loaded_token = state_dict["name"]
+            embedding = state_dict["string_to_param"]["*"]
+
+        if token is not None and loaded_token != token:
+            logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+        else:
+            token = loaded_token
+
+        embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
+        # 3. Make sure we don't mess up the tokenizer or text encoder
+        vocab = self.tokenizer.get_vocab()
+        if token in vocab:
+            raise ValueError(
+                f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+            )
+        elif f"{token}_1" in vocab:
+            multi_vector_tokens = [token]
+            i = 1
+            while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
+                multi_vector_tokens.append(f"{token}_{i}")
+                i += 1
+
+            raise ValueError(
+                f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+            )
+
+        is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+
+        if is_multi_vector:
+            tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+            embeddings = [e for e in embedding]  # noqa: C416
+        else:
+            tokens = [token]
+            # a (1, hidden_size) tensor is squeezed to its single row
+            embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+
+        # add tokens and get ids
+        self.tokenizer.add_tokens(tokens)
+        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+
+        # resize token embeddings and set new embeddings
+        self.text_encoder.resize_token_embeddings(len(self.tokenizer))
+        for token_id, embedding in zip(token_ids, embeddings):
+            self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+
+        logger.info(f"Loaded textual inversion embedding for {token}.")
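Two loading paths are worth illustrating: a `diffusers`-format file saved under a custom weight name, and an A1111-style file where `weight_name` is mandatory (the local paths, file names, and token below are made up for illustration):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# diffusers format saved under a non-default name (hypothetical local dir):
pipe.load_textual_inversion("./my_text_inversion_directory/", weight_name="text_inv.bin")

# A1111 format: weight_name must be given; `token` overrides the name stored
# in the file's "name" field (file name and token are hypothetical):
pipe.load_textual_inversion("./a1111_embeddings/", weight_name="charturner.pt", token="<charturner>")
```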
@@ -16,27 +16,22 @@

 import inspect
 import os
-import warnings
 from functools import partial
 from typing import Callable, List, Optional, Tuple, Union

 import torch
-from huggingface_hub import hf_hub_download
-from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
-from packaging import version
-from requests import HTTPError
 from torch import Tensor, device

 from .. import __version__
 from ..utils import (
     CONFIG_NAME,
-    DEPRECATED_REVISION_ARGS,
     DIFFUSERS_CACHE,
     FLAX_WEIGHTS_NAME,
     HF_HUB_OFFLINE,
-    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
+    _add_variant,
+    _get_model_file,
     is_accelerate_available,
     is_safetensors_available,
     is_torch_version,
@@ -144,15 +139,6 @@ def _load_state_dict_into_model(model_to_load, state_dict):
     return error_msgs


-def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
-    if variant is not None:
-        splits = weights_name.split(".")
-        splits = splits[:-1] + [variant] + splits[-1:]
-        weights_name = ".".join(splits)
-
-    return weights_name
-
-
 class ModelMixin(torch.nn.Module):
     r"""
     Base class for all models.
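`_add_variant` is moved, not changed: it splices a variant tag in before the file extension. A quick sanity check against the relocated helper (re-exported from `diffusers.utils` after this PR):

```python
from diffusers.utils import _add_variant

# Variant tags land just before the extension:
assert _add_variant("diffusion_pytorch_model.bin", "fp16") == "diffusion_pytorch_model.fp16.bin"
assert _add_variant("diffusion_pytorch_model.safetensors", "ema") == "diffusion_pytorch_model.ema.safetensors"
assert _add_variant("model.bin") == "model.bin"  # no variant: unchanged
```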
@@ -789,121 +775,3 @@ class ModelMixin(torch.nn.Module):
             return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
             return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
-
-
-def _get_model_file(
-    pretrained_model_name_or_path,
-    *,
-    weights_name,
-    subfolder,
-    cache_dir,
-    force_download,
-    proxies,
-    resume_download,
-    local_files_only,
-    use_auth_token,
-    user_agent,
-    revision,
-    commit_hash=None,
-):
-    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-    if os.path.isfile(pretrained_model_name_or_path):
-        return pretrained_model_name_or_path
-    elif os.path.isdir(pretrained_model_name_or_path):
-        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
-            # Load from a PyTorch checkpoint
-            model_file = os.path.join(pretrained_model_name_or_path, weights_name)
-            return model_file
-        elif subfolder is not None and os.path.isfile(
-            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
-        ):
-            model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
-            return model_file
-        else:
-            raise EnvironmentError(
-                f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
-            )
-    else:
-        # 1. First check if deprecated way of loading from branches is used
-        if (
-            revision in DEPRECATED_REVISION_ARGS
-            and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
-            and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
-        ):
-            try:
-                model_file = hf_hub_download(
-                    pretrained_model_name_or_path,
-                    filename=_add_variant(weights_name, revision),
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    proxies=proxies,
-                    resume_download=resume_download,
-                    local_files_only=local_files_only,
-                    use_auth_token=use_auth_token,
-                    user_agent=user_agent,
-                    subfolder=subfolder,
-                    revision=revision or commit_hash,
-                )
-                warnings.warn(
-                    f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
-                    FutureWarning,
-                )
-                return model_file
-            except:  # noqa: E722
-                warnings.warn(
-                    f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
-                    FutureWarning,
-                )
-        try:
-            # 2. Load model file as usual
-            model_file = hf_hub_download(
-                pretrained_model_name_or_path,
-                filename=weights_name,
-                cache_dir=cache_dir,
-                force_download=force_download,
-                proxies=proxies,
-                resume_download=resume_download,
-                local_files_only=local_files_only,
-                use_auth_token=use_auth_token,
-                user_agent=user_agent,
-                subfolder=subfolder,
-                revision=revision or commit_hash,
-            )
-            return model_file
-
-        except RepositoryNotFoundError:
-            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
-                "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
-                "login`."
-            )
-        except RevisionNotFoundError:
-            raise EnvironmentError(
-                f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
-                "this model name. Check the model page at "
-                f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
-            )
-        except EntryNotFoundError:
-            raise EnvironmentError(
-                f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
-            )
-        except HTTPError as err:
-            raise EnvironmentError(
-                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
-            )
-        except ValueError:
-            raise EnvironmentError(
-                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
-                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
-                f" directory containing a file named {weights_name} or"
-                " \nCheckout your internet connection or see how to run the library in"
-                " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
-            )
-        except EnvironmentError:
-            raise EnvironmentError(
-                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
-                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
-                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
-                f"containing a file named {weights_name}"
-            )
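The function is relocated verbatim to `diffusers.utils.hub_utils` and re-exported from `diffusers.utils` (see the import changes below), which is how `loaders.py` now resolves textual inversion weight files. A usage sketch with the keyword-only signature shown above:

```python
from diffusers.utils import DIFFUSERS_CACHE, _get_model_file

# Resolve a weights file from the Hub (or a local dir/file), the same way
# load_textual_inversion does:
model_file = _get_model_file(
    "sd-concepts-library/low-poly-hd-logos-icons",
    weights_name="learned_embeds.bin",
    cache_dir=DIFFUSERS_CACHE,
    force_download=False,
    resume_download=False,
    proxies=None,
    local_files_only=False,
    use_auth_token=None,
    revision=None,
    subfolder=None,
    user_agent={"file_type": "text_inversion", "framework": "pytorch"},
)
```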
@@ -22,6 +22,7 @@ from transformers import CLIPImageProcessor, XLMRobertaTokenizer
 from diffusers.utils import is_accelerate_available, is_accelerate_version

 from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
@@ -49,7 +50,7 @@ EXAMPLE_DOC_STRING = """


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionPipeline(DiffusionPipeline):
+class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Alt Diffusion.

@@ -312,6 +313,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -372,6 +377,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
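To see what the `maybe_convert_prompt` calls above do at inference time, a self-contained sketch (the `<concept>` token and `TinyPipe` harness are hypothetical; the mixin only needs `tokenizer` and `text_encoder` attributes):

```python
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers.loaders import TextualInversionLoaderMixin

# Minimal stand-in for a pipeline carrying the mixin.
class TinyPipe(TextualInversionLoaderMixin):
    def __init__(self):
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

pipe = TinyPipe()

# Register a made-up 3-vector concept by hand, mimicking what
# load_textual_inversion does for an embedding of shape (3, hidden_size).
pipe.tokenizer.add_tokens(["<concept>", "<concept>_1", "<concept>_2"])
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))

print(pipe.maybe_convert_prompt("a photo in <concept> style", pipe.tokenizer))
# -> "a photo in <concept> <concept>_1 <concept>_2 style"
```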
@@ -25,6 +25,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
@@ -88,7 +89,7 @@ def preprocess(image):


 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
-class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
+class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-guided image to image generation using Alt Diffusion.

@@ -322,6 +323,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -382,6 +387,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -24,6 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler
 from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
@@ -118,7 +119,7 @@ def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
     return noise


-class CycleDiffusionPipeline(DiffusionPipeline):
+class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.

@@ -338,6 +339,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -398,6 +403,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -20,6 +20,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -52,7 +53,7 @@ EXAMPLE_DOC_STRING = """
 """


-class StableDiffusionPipeline(DiffusionPipeline):
+class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

@@ -315,6 +316,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -375,6 +380,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -21,6 +21,7 @@ import torch
 from torch.nn import functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import Attention
 from ...schedulers import KarrasDiffusionSchedulers
@@ -159,7 +160,7 @@ class AttendExciteAttnProcessor:
         return hidden_states


-class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
+class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion and Attend and Excite.

@@ -335,6 +336,10 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -395,6 +400,10 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.controlnet import ControlNetOutput
 from ...models.modeling_utils import ModelMixin
@@ -146,7 +147,7 @@ class MultiControlNetModel(ModelMixin):
         return down_block_res_samples, mid_block_res_sample


-class StableDiffusionControlNetPipeline(DiffusionPipeline):
+class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.

@@ -354,6 +355,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -414,6 +419,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -23,6 +23,7 @@ from packaging import version
 from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation

 from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
@@ -54,7 +55,7 @@ def preprocess(image):
     return image


-class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
+class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.

@@ -200,6 +201,10 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -260,6 +265,10 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -23,6 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -91,7 +92,7 @@ def preprocess(image):
     return image


-class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
+class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-guided image to image generation using Stable Diffusion.

@@ -329,6 +330,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -389,6 +394,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -22,6 +22,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -137,7 +138,7 @@ def prepare_mask_and_masked_image(image, mask):
     return mask, masked_image


-class StableDiffusionInpaintPipeline(DiffusionPipeline):
+class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

@@ -381,6 +382,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -441,6 +446,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -22,6 +22,7 @@ from packaging import version
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...configuration_utils import FrozenDict
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -81,7 +82,7 @@ def preprocess_mask(mask, scale_factor=8):
     return mask


-class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
+class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.

@@ -317,6 +318,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -377,6 +382,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -20,6 +20,7 @@ import PIL
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -60,7 +61,7 @@ def preprocess(image):
     return image


-class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
+class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.

@@ -511,6 +512,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -571,6 +576,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Union
 import torch
 from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser

+from ...loaders import TextualInversionLoaderMixin
 from ...pipelines import DiffusionPipeline
 from ...schedulers import LMSDiscreteScheduler
 from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -41,7 +42,7 @@ class ModelWrapper:
         return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample


-class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
+class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

@@ -238,6 +239,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -298,6 +303,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import torch
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import PNDMScheduler
 from ...schedulers.scheduling_utils import SchedulerMixin
@@ -52,7 +53,7 @@ EXAMPLE_DOC_STRING = """
 """


-class StableDiffusionModelEditingPipeline(DiffusionPipeline):
+class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models".

@@ -266,6 +267,10 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -326,6 +331,10 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -17,6 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler, PNDMScheduler
 from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
@@ -47,7 +48,7 @@ EXAMPLE_DOC_STRING = """
 """


-class StableDiffusionPanoramaPipeline(DiffusionPipeline):
+class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image
     Generation".
@@ -230,6 +231,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -290,6 +295,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -28,6 +28,7 @@ from transformers import (
     CLIPTokenizer,
 )

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import Attention
 from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
@@ -50,7 +51,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


 @dataclass
-class Pix2PixInversionPipelineOutput(BaseOutput):
+class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin):
     """
     Output class for Stable Diffusion pipelines.

@@ -470,6 +471,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -530,6 +535,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -19,6 +19,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
@@ -87,7 +88,7 @@ class CrossAttnStoreProcessor:


 # Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input
-class StableDiffusionSAGPipeline(DiffusionPipeline):
+class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion.

@@ -247,6 +248,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -307,6 +312,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -20,6 +20,7 @@ import PIL
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
@@ -50,7 +51,7 @@ def preprocess(image):
     return image


-class StableDiffusionUpscalePipeline(DiffusionPipeline):
+class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
     Pipeline for text-guided image super-resolution using Stable Diffusion 2.

@@ -194,6 +195,10 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -254,6 +259,10 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -19,6 +19,7 @@ import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
@@ -47,7 +48,7 @@ EXAMPLE_DOC_STRING = """
 """


-class StableUnCLIPPipeline(DiffusionPipeline):
+class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     """
     Pipeline for text-to-image generation using stable unCLIP.

@@ -367,6 +368,10 @@ class StableUnCLIPPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -427,6 +432,10 @@ class StableUnCLIPPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -21,6 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV

 from diffusers.utils.import_utils import is_accelerate_available

+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
@@ -60,7 +61,7 @@ EXAMPLE_DOC_STRING = """
 """


-class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
+class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     """
     Pipeline for text-guided image to image generation using stable unCLIP.

@@ -267,6 +268,10 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
             text_inputs = self.tokenizer(
                 prompt,
                 padding="max_length",
@@ -327,6 +332,10 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
             else:
                 uncond_tokens = negative_prompt

+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
             max_length = prompt_embeds.shape[1]
             uncond_input = self.tokenizer(
                 uncond_tokens,
@@ -19,6 +19,7 @@ import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer

from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (

@@ -72,7 +73,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -
return images


class TextToVideoSDPipeline(DiffusionPipeline):
class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-video generation.

@@ -256,6 +257,10 @@ class TextToVideoSDPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]

if prompt_embeds is None:
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

text_inputs = self.tokenizer(
prompt,
padding="max_length",

@@ -316,6 +321,10 @@ class TextToVideoSDPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt

# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

@@ -37,6 +37,8 @@ from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import (
HF_HUB_OFFLINE,
_add_variant,
_get_model_file,
extract_commit_hash,
http_user_agent,
)

@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends


class TextualInversionLoaderMixin(metaclass=DummyObject):
_backends = ["torch", "transformers"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])


class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

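The dummy class keeps the import path working when torch or transformers is missing, deferring the failure from import time to first use; roughly:

# In an environment without torch/transformers, the dummy object is exported
# instead of the real mixin; touching it raises a helpful ImportError.
from diffusers.utils.dummy_torch_and_transformers_objects import (
    TextualInversionLoaderMixin,
)

try:
    TextualInversionLoaderMixin()
except ImportError as err:
    print(err)  # tells the user which backends to install
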
@@ -18,16 +18,30 @@ import os
import re
import sys
import traceback
import warnings
from pathlib import Path
from typing import Dict, Optional, Union
from uuid import uuid4

from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import is_jinja_available
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
is_jinja_available,
)
from packaging import version
from requests import HTTPError

from .. import __version__
from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
from .constants import (
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
)
from .import_utils import (
ENV_VARS_TRUE_VALUES,
_flax_version,

@@ -215,3 +229,130 @@ if cache_version < 1:
f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
"the directory exists and can be written to."
)


def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)

return weights_name

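`_add_variant` splices the variant name in front of the file extension; for example:

assert _add_variant("diffusion_pytorch_model.bin", "fp16") == "diffusion_pytorch_model.fp16.bin"
assert _add_variant("learned_embeds.safetensors", "non_ema") == "learned_embeds.non_ema.safetensors"
assert _add_variant("pytorch_lora_weights.bin") == "pytorch_lora_weights.bin"  # no variant: unchanged
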
def _get_model_file(
pretrained_model_name_or_path,
*,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
commit_hash=None,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
return pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
return model_file
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
return model_file
else:
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
else:
# 1. First check if deprecated way of loading from branches is used
if (
revision in DEPRECATED_REVISION_ARGS
and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
):
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=_add_variant(weights_name, revision),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision or commit_hash,
)
warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
return model_file
except: # noqa: E722
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
FutureWarning,
)
try:
# 2. Load model file as usual
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision or commit_hash,
)
return model_file

except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
)

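All keyword arguments after the `*` are required, so callers spell them out explicitly; a hedged sketch of how a loader might invoke the helper (the repo id and user agent values are illustrative, not taken from this patch):

model_file = _get_model_file(
    "sd-concepts-library/cat-toy",  # repo id, local file, or local directory
    weights_name="learned_embeds.bin",
    subfolder=None,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    local_files_only=False,
    use_auth_token=None,
    user_agent={"file_type": "text_inversion", "framework": "pytorch"},
    revision=None,
)
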
@@ -21,6 +21,7 @@ import unittest

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (

@@ -886,6 +887,32 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
assert mem_bytes_slicing < mem_bytes_offloaded
assert mem_bytes_slicing < 3 * 10**9

def test_stable_diffusion_textual_inversion(self):
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")

a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt")
a111_file_neg = hf_hub_download(
"hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt"
)
pipe.load_textual_inversion(a111_file)
pipe.load_textual_inversion(a111_file_neg)
pipe.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(1)

prompt = "An logo of a turtle in strong Style-Winter with <low-poly-hd-logos-icons>"
neg_prompt = "Style-Winter-neg"

image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0]

expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy"
)

max_diff = np.abs(expected_image - image).max()
assert max_diff < 5e-3


@nightly
@require_torch_gpu

@@ -362,6 +362,97 @@ class DownloadTests(unittest.TestCase):

diffusers.utils.import_utils._safetensors_available = True

def test_text_inversion_download(self):
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)
pipe = pipe.to(torch_device)

num_tokens = len(pipe.tokenizer)

# single token load local
with tempfile.TemporaryDirectory() as tmpdirname:
ten = {"<*>": torch.ones((32,))}
torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))

pipe.load_textual_inversion(tmpdirname)

token = pipe.tokenizer.convert_tokens_to_ids("<*>")
assert token == num_tokens, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32
assert pipe._maybe_convert_prompt("<*>", pipe.tokenizer) == "<*>"

prompt = "hey <*>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)

# single token load local with weight name
with tempfile.TemporaryDirectory() as tmpdirname:
ten = {"<**>": 2 * torch.ones((1, 32))}
torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))

pipe.load_textual_inversion(tmpdirname, weight_name="learned_embeds.bin")

token = pipe.tokenizer.convert_tokens_to_ids("<**>")
assert token == num_tokens + 1, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
assert pipe._maybe_convert_prompt("<**>", pipe.tokenizer) == "<**>"

prompt = "hey <**>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)

# multi token load
with tempfile.TemporaryDirectory() as tmpdirname:
ten = {"<***>": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])}
torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))

pipe.load_textual_inversion(tmpdirname)

token = pipe.tokenizer.convert_tokens_to_ids("<***>")
token_1 = pipe.tokenizer.convert_tokens_to_ids("<***>_1")
token_2 = pipe.tokenizer.convert_tokens_to_ids("<***>_2")

assert token == num_tokens + 2, "Added token must be at spot `num_tokens`"
assert token_1 == num_tokens + 3, "Added token must be at spot `num_tokens`"
assert token_2 == num_tokens + 4, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***><***>_1<***>_2"

prompt = "hey <***>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)

# multi token load a1111
with tempfile.TemporaryDirectory() as tmpdirname:
ten = {
"string_to_param": {
"*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])
},
"name": "<****>",
}
torch.save(ten, os.path.join(tmpdirname, "a1111.bin"))

pipe.load_textual_inversion(tmpdirname, weight_name="a1111.bin")

token = pipe.tokenizer.convert_tokens_to_ids("<****>")
token_1 = pipe.tokenizer.convert_tokens_to_ids("<****>_1")
token_2 = pipe.tokenizer.convert_tokens_to_ids("<****>_2")

assert token == num_tokens + 5, "Added token must be at spot `num_tokens`"
assert token_1 == num_tokens + 6, "Added token must be at spot `num_tokens`"
assert token_2 == num_tokens + 7, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****><****>_1<****>_2"

prompt = "hey <****>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)

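The two on-disk layouts exercised by this test, side by side (key names as asserted above; tensor shapes are illustrative):

import torch

# diffusers format: the placeholder token is the key, the value holds the
# embedding vector(s).
diffusers_state = {"<concept>": torch.randn(3, 768)}

# AUTOMATIC1111 webui format: vectors sit under "string_to_param" and the
# trigger word under "name"; load_textual_inversion detects this layout.
a1111_state = {
    "string_to_param": {"*": torch.randn(3, 768)},
    "name": "<concept>",
}
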
class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):