1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add load textual inversion embeddings to stable diffusion (#2009)

* add load textual inversion embeddings draft

* fix quality

* fix typo

* make fix copies

* move to textual inversion mixin

* make it accept from sd-concept library

* accept list of paths to embeddings

* fix styling of stable diffusion pipeline

* add dummy TextualInversionMixin

* add docstring to textualinversionmixin

* add load textual inversion embeddings draft

* fix quality

* fix typo

* make fix copies

* move to textual inversion mixin

* make it accept from sd-concept library

* accept list of paths to embeddings

* fix styling of stable diffusion pipeline

* add dummy TextualInversionMixin

* add docstring to textualinversionmixin

* add case for parsing embedding from auto1111 UI format

Co-authored-by: Evan Jones <evan.a.jones3@gmail.com>
Co-authored-by: Ana Tamais <aninhamoraestamais@gmail.com>

* fix style after rebase

* move textual inversion mixin to loaders

* move mixin inheritance to DiffusionPipeline from StableDiffusionPipeline)

* update dummy class name

* addressed allo comments

* fix old dangling import

* fix style

* proposal

* remove bogus

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>

* finish

* make style

* up

* fix code quality

* fix code quality - again

* fix code quality - 3

* fix alt diffusion code quality

* fix model editing pipeline

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Finish

---------

Co-authored-by: Evan Jones <evan.a.jones3@gmail.com>
Co-authored-by: Ana Tamais <aninhamoraestamais@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Will Berman <wlbberman@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Pi Esposito
2023-03-30 14:08:39 -03:00
committed by GitHub
parent 1d033a95f6
commit a937e1b594
28 changed files with 766 additions and 168 deletions

View File

@@ -109,6 +109,7 @@ try:
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .loaders import TextualInversionLoaderMixin
from .pipelines import (
AltDiffusionImg2ImgPipeline,
AltDiffusionPipeline,

View File

@@ -13,18 +13,28 @@
# limitations under the License.
import os
from collections import defaultdict
from typing import Callable, Dict, Union
from typing import Callable, Dict, List, Optional, Union
import torch
from .models.attention_processor import LoRAAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
deprecate,
is_safetensors_available,
is_transformers_available,
logging,
)
if is_safetensors_available():
import safetensors
if is_transformers_available():
from transformers import PreTrainedModel, PreTrainedTokenizer
logger = logging.get_logger(__name__)
@@ -32,6 +42,9 @@ logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
@@ -123,13 +136,6 @@ class UNet2DConditionLoadersMixin:
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.
</Tip>
"""
@@ -292,5 +298,272 @@ class UNet2DConditionLoadersMixin:
# Save the model
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
class TextualInversionLoaderMixin:
r"""
Mixin class for loading textual inversion tokens and embeddings to the tokenizer and text encoder.
"""
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: PreTrainedTokenizer):
r"""
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
Parameters:
prompt (`str` or list of `str`):
The prompt or prompts to guide the image generation.
tokenizer (`PreTrainedTokenizer`):
The tokenizer responsible for encoding the prompt into input tokens.
Returns:
`str` or list of `str`: The converted prompt
"""
if not isinstance(prompt, List):
prompts = [prompt]
else:
prompts = prompt
prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
if not isinstance(prompt, List):
return prompts[0]
return prompts
def _maybe_convert_prompt(self, prompt: str, tokenizer: PreTrainedTokenizer):
r"""
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
Parameters:
prompt (`str`):
The prompt to guide the image generation.
tokenizer (`PreTrainedTokenizer`):
The tokenizer responsible for encoding the prompt into input tokens.
Returns:
`str`: The converted prompt
"""
tokens = tokenizer.tokenize(prompt)
for token in tokens:
if token in tokenizer.added_tokens_encoder:
replacement = token
i = 1
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
replacement += f"{token}_{i}"
i += 1
prompt = prompt.replace(token, replacement)
return prompt
def load_textual_inversion(
self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs
):
r"""
Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and
`Automatic1111` formats are supported.
<Tip warning={true}>
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids should have an organization name, like
`"sd-concepts-library/low-poly-hd-logos-icons"`.
- A path to a *directory* containing textual inversion weights, e.g.
`./my_text_inversion_directory/`.
weight_name (`str`, *optional*):
Name of a custom weight file. This should be used in two cases:
- The saved textual inversion file is in `diffusers` format, but was saved under a specific weight
name, such as `text_inv.bin`.
- The saved textual inversion file is in the "Automatic1111" form.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `diffusers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
"""
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
raise ValueError(
f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
f" `{self.load_textual_inversion.__name__}`"
)
if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
raise ValueError(
f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
f" `{self.load_textual_inversion.__name__}`"
)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
"`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
)
allow_pickle = False
if use_safetensors is None:
use_safetensors = is_safetensors_available()
allow_pickle = True
user_agent = {
"file_type": "text_inversion",
"framework": "pytorch",
}
# 1. Load textual inversion file
model_file = None
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except Exception as e:
if not allow_pickle:
raise e
model_file = None
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
# 2. Load token and embedding correcly from file
if isinstance(state_dict, torch.Tensor):
if token is None:
raise ValueError(
"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
)
embedding = state_dict
elif len(state_dict) == 1:
# diffusers
loaded_token, embedding = next(iter(state_dict.items()))
elif "string_to_param" in state_dict:
# A1111
loaded_token = state_dict["name"]
embedding = state_dict["string_to_param"]["*"]
if token is not None and loaded_token != token:
logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
else:
token = loaded_token
embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
# 3. Make sure we don't mess up the tokenizer or text encoder
vocab = self.tokenizer.get_vocab()
if token in vocab:
raise ValueError(
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
)
elif f"{token}_1" in vocab:
multi_vector_tokens = [token]
i = 1
while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
multi_vector_tokens.append(f"{token}_{i}")
i += 1
raise ValueError(
f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
)
is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
if is_multi_vector:
tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
embeddings = [e for e in embedding] # noqa: C416
else:
tokens = [token]
embeddings = [embedding] if len(embedding.shape) > 1 else [embedding[0]]
# add tokens and get ids
self.tokenizer.add_tokens(tokens)
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
# resize token embeddings and set new embeddings
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
for token_id, embedding in zip(token_ids, embeddings):
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
logger.info("Loaded textual inversion embedding for {token}.")

View File

@@ -16,27 +16,22 @@
import inspect
import os
import warnings
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from packaging import version
from requests import HTTPError
from torch import Tensor, device
from .. import __version__
from ..utils import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_add_variant,
_get_model_file,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
@@ -144,15 +139,6 @@ def _load_state_dict_into_model(model_to_load, state_dict):
return error_msgs
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)
return weights_name
class ModelMixin(torch.nn.Module):
r"""
Base class for all models.
@@ -789,121 +775,3 @@ class ModelMixin(torch.nn.Module):
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
else:
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
def _get_model_file(
pretrained_model_name_or_path,
*,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
commit_hash=None,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
return pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
return model_file
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
return model_file
else:
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
else:
# 1. First check if deprecated way of loading from branches is used
if (
revision in DEPRECATED_REVISION_ARGS
and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
):
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=_add_variant(weights_name, revision),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision or commit_hash,
)
warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
return model_file
except: # noqa: E722
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
FutureWarning,
)
try:
# 2. Load model file as usual
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision or commit_hash,
)
return model_file
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
)

View File

@@ -22,6 +22,7 @@ from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
@@ -49,7 +50,7 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionPipeline(DiffusionPipeline):
class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Alt Diffusion.
@@ -312,6 +313,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -372,6 +377,10 @@ class AltDiffusionPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -25,6 +25,7 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
@@ -88,7 +89,7 @@ def preprocess(image):
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Alt Diffusion.
@@ -322,6 +323,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -382,6 +387,10 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -24,6 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.utils import is_accelerate_available, is_accelerate_version
from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
@@ -118,7 +119,7 @@ def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta):
return noise
class CycleDiffusionPipeline(DiffusionPipeline):
class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
@@ -338,6 +339,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -398,6 +403,10 @@ class CycleDiffusionPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -20,6 +20,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -52,7 +53,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableDiffusionPipeline(DiffusionPipeline):
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -315,6 +316,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -375,6 +380,10 @@ class StableDiffusionPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -21,6 +21,7 @@ import torch
from torch.nn import functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...schedulers import KarrasDiffusionSchedulers
@@ -159,7 +160,7 @@ class AttendExciteAttnProcessor:
return hidden_states
class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion and Attend and Excite.
@@ -335,6 +336,10 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -395,6 +400,10 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -23,6 +23,7 @@ import torch
from torch import nn
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.controlnet import ControlNetOutput
from ...models.modeling_utils import ModelMixin
@@ -146,7 +147,7 @@ class MultiControlNetModel(ModelMixin):
return down_block_res_samples, mid_block_res_sample
class StableDiffusionControlNetPipeline(DiffusionPipeline):
class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -354,6 +355,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -414,6 +419,10 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -23,6 +23,7 @@ from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
@@ -54,7 +55,7 @@ def preprocess(image):
return image
class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
@@ -200,6 +201,10 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -260,6 +265,10 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -23,6 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -91,7 +92,7 @@ def preprocess(image):
return image
class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
class StableDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image to image generation using Stable Diffusion.
@@ -329,6 +330,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -389,6 +394,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -22,6 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -137,7 +138,7 @@ def prepare_mask_and_masked_image(image, mask):
return mask, masked_image
class StableDiffusionInpaintPipeline(DiffusionPipeline):
class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
@@ -381,6 +382,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -441,6 +446,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -22,6 +22,7 @@ from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -81,7 +82,7 @@ def preprocess_mask(mask, scale_factor=8):
return mask
class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*.
@@ -317,6 +318,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -377,6 +382,10 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -20,6 +20,7 @@ import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -60,7 +61,7 @@ def preprocess(image):
return image
class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion.
@@ -511,6 +512,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -571,6 +576,10 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -18,6 +18,7 @@ from typing import Callable, List, Optional, Union
import torch
from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser
from ...loaders import TextualInversionLoaderMixin
from ...pipelines import DiffusionPipeline
from ...schedulers import LMSDiscreteScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor
@@ -41,7 +42,7 @@ class ModelWrapper:
return self.model(*args, encoder_hidden_states=encoder_hidden_states, **kwargs).sample
class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -238,6 +239,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -298,6 +303,10 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -18,6 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...schedulers.scheduling_utils import SchedulerMixin
@@ -52,7 +53,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableDiffusionModelEditingPipeline(DiffusionPipeline):
class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models".
@@ -266,6 +267,10 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -326,6 +331,10 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -17,6 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, PNDMScheduler
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
@@ -47,7 +48,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableDiffusionPanoramaPipeline(DiffusionPipeline):
class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image
Generation".
@@ -230,6 +231,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -290,6 +295,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -28,6 +28,7 @@ from transformers import (
CLIPTokenizer,
)
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler
@@ -50,7 +51,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Pix2PixInversionPipelineOutput(BaseOutput):
class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin):
"""
Output class for Stable Diffusion pipelines.
@@ -470,6 +471,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -530,6 +535,10 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -19,6 +19,7 @@ import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
@@ -87,7 +88,7 @@ class CrossAttnStoreProcessor:
# Modified to get self-attention guidance scale in this paper (https://arxiv.org/pdf/2210.00939.pdf) as an input
class StableDiffusionSAGPipeline(DiffusionPipeline):
class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -247,6 +248,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -307,6 +312,10 @@ class StableDiffusionSAGPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -20,6 +20,7 @@ import PIL
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
@@ -50,7 +51,7 @@ def preprocess(image):
return image
class StableDiffusionUpscalePipeline(DiffusionPipeline):
class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-guided image super-resolution using Stable Diffusion 2.
@@ -194,6 +195,10 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -254,6 +259,10 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -19,6 +19,7 @@ import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding
from ...schedulers import KarrasDiffusionSchedulers
@@ -47,7 +48,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableUnCLIPPipeline(DiffusionPipeline):
class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
"""
Pipeline for text-to-image generation using stable unCLIP.
@@ -367,6 +368,10 @@ class StableUnCLIPPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -427,6 +432,10 @@ class StableUnCLIPPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -21,6 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
from diffusers.utils.import_utils import is_accelerate_available
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.embeddings import get_timestep_embedding
from ...schedulers import KarrasDiffusionSchedulers
@@ -60,7 +61,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
"""
Pipeline for text-guided image to image generation using stable unCLIP.
@@ -267,6 +268,10 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -327,6 +332,10 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -19,6 +19,7 @@ import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...loaders import TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
@@ -72,7 +73,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) -
return images
class TextToVideoSDPipeline(DiffusionPipeline):
class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
r"""
Pipeline for text-to-video generation.
@@ -256,6 +257,10 @@ class TextToVideoSDPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
@@ -316,6 +321,10 @@ class TextToVideoSDPipeline(DiffusionPipeline):
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,

View File

@@ -37,6 +37,8 @@ from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import (
HF_HUB_OFFLINE,
_add_variant,
_get_model_file,
extract_commit_hash,
http_user_agent,
)

View File

@@ -2,6 +2,21 @@
from ..utils import DummyObject, requires_backends
class TextualInversionLoaderMixin(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -18,16 +18,30 @@ import os
import re
import sys
import traceback
import warnings
from pathlib import Path
from typing import Dict, Optional, Union
from uuid import uuid4
from huggingface_hub import HfFolder, ModelCard, ModelCardData, whoami
from huggingface_hub import HfFolder, ModelCard, ModelCardData, hf_hub_download, whoami
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import is_jinja_available
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
is_jinja_available,
)
from packaging import version
from requests import HTTPError
from .. import __version__
from .constants import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT
from .constants import (
DEPRECATED_REVISION_ARGS,
DIFFUSERS_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
)
from .import_utils import (
ENV_VARS_TRUE_VALUES,
_flax_version,
@@ -215,3 +229,130 @@ if cache_version < 1:
f"There was a problem when trying to write in your cache folder ({DIFFUSERS_CACHE}). Please, ensure "
"the directory exists and can be written to."
)
def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
if variant is not None:
splits = weights_name.split(".")
splits = splits[:-1] + [variant] + splits[-1:]
weights_name = ".".join(splits)
return weights_name
def _get_model_file(
pretrained_model_name_or_path,
*,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
commit_hash=None,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
return pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
return model_file
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
return model_file
else:
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
else:
# 1. First check if deprecated way of loading from branches is used
if (
revision in DEPRECATED_REVISION_ARGS
and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME)
and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0")
):
try:
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=_add_variant(weights_name, revision),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision or commit_hash,
)
warnings.warn(
f"Loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` is deprecated. Loading instead from `revision='main'` with `variant={revision}`. Loading model variants via `revision='{revision}'` will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
)
return model_file
except: # noqa: E722
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have a {_add_variant(weights_name, revision)} file in the 'main' branch of {pretrained_model_name_or_path}. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {_add_variant(weights_name, revision)}' so that the correct variant file can be added.",
FutureWarning,
)
try:
# 2. Load model file as usual
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision or commit_hash,
)
return model_file
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
except HTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
)

View File

@@ -21,6 +21,7 @@ import unittest
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
@@ -886,6 +887,32 @@ class StableDiffusionPipelineSlowTests(unittest.TestCase):
assert mem_bytes_slicing < mem_bytes_offloaded
assert mem_bytes_slicing < 3 * 10**9
def test_stable_diffusion_textual_inversion(self):
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.load_textual_inversion("sd-concepts-library/low-poly-hd-logos-icons")
a111_file = hf_hub_download("hf-internal-testing/text_inv_embedding_a1111_format", "winter_style.pt")
a111_file_neg = hf_hub_download(
"hf-internal-testing/text_inv_embedding_a1111_format", "winter_style_negative.pt"
)
pipe.load_textual_inversion(a111_file)
pipe.load_textual_inversion(a111_file_neg)
pipe.to("cuda")
generator = torch.Generator(device="cpu").manual_seed(1)
prompt = "An logo of a turtle in strong Style-Winter with <low-poly-hd-logos-icons>"
neg_prompt = "Style-Winter-neg"
image = pipe(prompt=prompt, negative_prompt=neg_prompt, generator=generator, output_type="np").images[0]
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text_inv/winter_logo_style.npy"
)
max_diff = np.abs(expected_image - image).max()
assert max_diff < 5e-3
@nightly
@require_torch_gpu

View File

@@ -362,6 +362,97 @@ class DownloadTests(unittest.TestCase):
diffusers.utils.import_utils._safetensors_available = True
def test_text_inversion_download(self):
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)
pipe = pipe.to(torch_device)
num_tokens = len(pipe.tokenizer)
# single token load local
with tempfile.TemporaryDirectory() as tmpdirname:
ten = {"<*>": torch.ones((32,))}
torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))
pipe.load_textual_inversion(tmpdirname)
token = pipe.tokenizer.convert_tokens_to_ids("<*>")
assert token == num_tokens, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32
assert pipe._maybe_convert_prompt("<*>", pipe.tokenizer) == "<*>"
prompt = "hey <*>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)
# single token load local with weight name
with tempfile.TemporaryDirectory() as tmpdirname:
ten = {"<**>": 2 * torch.ones((1, 32))}
torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))
pipe.load_textual_inversion(tmpdirname, weight_name="learned_embeds.bin")
token = pipe.tokenizer.convert_tokens_to_ids("<**>")
assert token == num_tokens + 1, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
assert pipe._maybe_convert_prompt("<**>", pipe.tokenizer) == "<**>"
prompt = "hey <**>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)
# multi token load
with tempfile.TemporaryDirectory() as tmpdirname:
ten = {"<***>": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])}
torch.save(ten, os.path.join(tmpdirname, "learned_embeds.bin"))
pipe.load_textual_inversion(tmpdirname)
token = pipe.tokenizer.convert_tokens_to_ids("<***>")
token_1 = pipe.tokenizer.convert_tokens_to_ids("<***>_1")
token_2 = pipe.tokenizer.convert_tokens_to_ids("<***>_2")
assert token == num_tokens + 2, "Added token must be at spot `num_tokens`"
assert token_1 == num_tokens + 3, "Added token must be at spot `num_tokens`"
assert token_2 == num_tokens + 4, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
assert pipe._maybe_convert_prompt("<***>", pipe.tokenizer) == "<***><***>_1<***>_2"
prompt = "hey <***>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)
# multi token load a1111
with tempfile.TemporaryDirectory() as tmpdirname:
ten = {
"string_to_param": {
"*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])
},
"name": "<****>",
}
torch.save(ten, os.path.join(tmpdirname, "a1111.bin"))
pipe.load_textual_inversion(tmpdirname, weight_name="a1111.bin")
token = pipe.tokenizer.convert_tokens_to_ids("<****>")
token_1 = pipe.tokenizer.convert_tokens_to_ids("<****>_1")
token_2 = pipe.tokenizer.convert_tokens_to_ids("<****>_2")
assert token == num_tokens + 5, "Added token must be at spot `num_tokens`"
assert token_1 == num_tokens + 6, "Added token must be at spot `num_tokens`"
assert token_2 == num_tokens + 7, "Added token must be at spot `num_tokens`"
assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
assert pipe._maybe_convert_prompt("<****>", pipe.tokenizer) == "<****><****>_1<****>_2"
prompt = "hey <****>"
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
assert out.shape == (1, 128, 128, 3)
class CustomPipelineTests(unittest.TestCase):
def test_load_custom_pipeline(self):