mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Textual inversion] Relax loading textual inversion (#4903)
* [Textual inversion] Relax loading textual inversion * up
This commit is contained in:
committed by
GitHub
parent
6c314ad0ce
commit
dc3e0ca59b
@@ -663,6 +663,8 @@ class TextualInversionLoaderMixin:
|
||||
self,
|
||||
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
|
||||
token: Optional[Union[str, List[str]]] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizer] = None,
|
||||
text_encoder: Optional[PreTrainedModel] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -684,6 +686,11 @@ class TextualInversionLoaderMixin:
|
||||
token (`str` or `List[str]`, *optional*):
|
||||
Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
|
||||
list, then `token` must also be a list of equal length.
|
||||
text_encoder ([`~transformers.CLIPTextModel`], *optional*):
|
||||
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
|
||||
If not specified, function will take self.text_encoder.
|
||||
tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
|
||||
A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer.
|
||||
weight_name (`str`, *optional*):
|
||||
Name of a custom weight file. This should be used when:
|
||||
|
||||
@@ -757,15 +764,18 @@ class TextualInversionLoaderMixin:
|
||||
```
|
||||
|
||||
"""
|
||||
if not hasattr(self, "tokenizer") or not isinstance(self.tokenizer, PreTrainedTokenizer):
|
||||
tokenizer = tokenizer or getattr(self, "tokenizer", None)
|
||||
text_encoder = text_encoder or getattr(self, "text_encoder", None)
|
||||
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} requires `self.tokenizer` of type `PreTrainedTokenizer` for calling"
|
||||
f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
|
||||
f" `{self.load_textual_inversion.__name__}`"
|
||||
)
|
||||
|
||||
if not hasattr(self, "text_encoder") or not isinstance(self.text_encoder, PreTrainedModel):
|
||||
if text_encoder is None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} requires `self.text_encoder` of type `PreTrainedModel` for calling"
|
||||
f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
|
||||
f" `{self.load_textual_inversion.__name__}`"
|
||||
)
|
||||
|
||||
@@ -830,7 +840,7 @@ class TextualInversionLoaderMixin:
|
||||
token_ids_and_embeddings = []
|
||||
|
||||
for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
|
||||
if not isinstance(pretrained_model_name_or_path, dict):
|
||||
if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
|
||||
# 1. Load textual inversion file
|
||||
model_file = None
|
||||
# Let's first try to load .safetensors weights
|
||||
@@ -897,10 +907,10 @@ class TextualInversionLoaderMixin:
|
||||
else:
|
||||
token = loaded_token
|
||||
|
||||
embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device)
|
||||
embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device)
|
||||
|
||||
# 3. Make sure we don't mess up the tokenizer or text encoder
|
||||
vocab = self.tokenizer.get_vocab()
|
||||
vocab = tokenizer.get_vocab()
|
||||
if token in vocab:
|
||||
raise ValueError(
|
||||
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
|
||||
@@ -908,7 +918,7 @@ class TextualInversionLoaderMixin:
|
||||
elif f"{token}_1" in vocab:
|
||||
multi_vector_tokens = [token]
|
||||
i = 1
|
||||
while f"{token}_{i}" in self.tokenizer.added_tokens_encoder:
|
||||
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
|
||||
multi_vector_tokens.append(f"{token}_{i}")
|
||||
i += 1
|
||||
|
||||
@@ -926,16 +936,16 @@ class TextualInversionLoaderMixin:
|
||||
embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
|
||||
|
||||
# add tokens and get ids
|
||||
self.tokenizer.add_tokens(tokens)
|
||||
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
|
||||
tokenizer.add_tokens(tokens)
|
||||
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
token_ids_and_embeddings += zip(token_ids, embeddings)
|
||||
|
||||
logger.info(f"Loaded textual inversion embedding for {token}.")
|
||||
|
||||
# resize token embeddings and set all new embeddings
|
||||
self.text_encoder.resize_token_embeddings(len(self.tokenizer))
|
||||
text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
for token_id, embedding in token_ids_and_embeddings:
|
||||
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
||||
text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
||||
|
||||
# offload back
|
||||
if is_model_cpu_offload:
|
||||
|
||||
@@ -84,7 +84,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
|
||||
class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion XL.
|
||||
|
||||
|
||||
@@ -84,7 +84,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
|
||||
class StableDiffusionXLImg2ImgPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for image-to-image generation using Stable Diffusion XL.
|
||||
|
||||
|
||||
@@ -230,7 +230,9 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
|
||||
return mask, masked_image
|
||||
|
||||
|
||||
class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin):
|
||||
class StableDiffusionXLInpaintPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-guided image inpainting using Stable Diffusion XL.
|
||||
|
||||
|
||||
@@ -62,7 +62,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
|
||||
class StableDiffusionXLInstructPix2PixPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL.
|
||||
|
||||
|
||||
@@ -123,7 +123,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
|
||||
class StableDiffusionXLAdapterPipeline(
|
||||
DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter
|
||||
https://arxiv.org/abs/2302.08453
|
||||
|
||||
Reference in New Issue
Block a user