mirror of https://github.com/huggingface/diffusers.git
[Textual inversion] Refactor textual inversion to make it cleaner (#5076)
* [Textual inversion] Clean loading
* [Textual inversion] Clean loading
* [Textual inversion] Clean up
* [Textual inversion] Clean up
* [Textual inversion] Clean up
* [Textual inversion] Clean up
commit 7b39f43c06 (parent bfc606301f), committed by GitHub
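For orientation, the user-facing entry point this refactor reorganizes is `TextualInversionLoaderMixin.load_textual_inversion`. A minimal usage sketch (the pipeline checkpoint and concept repo are illustrative examples, not part of this commit):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Registers the "<cat-toy>" token with the tokenizer and copies its
# embedding vector into the text encoder's embedding matrix.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
image = pipe("A <cat-toy> train").images[0]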
@@ -623,6 +623,81 @@ class UNet2DConditionLoadersMixin:
             module._unfuse_lora()


+def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
+    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+    force_download = kwargs.pop("force_download", False)
+    resume_download = kwargs.pop("resume_download", False)
+    proxies = kwargs.pop("proxies", None)
+    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+    use_auth_token = kwargs.pop("use_auth_token", None)
+    revision = kwargs.pop("revision", None)
+    subfolder = kwargs.pop("subfolder", None)
+    weight_name = kwargs.pop("weight_name", None)
+    use_safetensors = kwargs.pop("use_safetensors", None)
+
+    allow_pickle = False
+    if use_safetensors is None:
+        use_safetensors = True
+        allow_pickle = True
+
+    user_agent = {
+        "file_type": "text_inversion",
+        "framework": "pytorch",
+    }
+
+    state_dicts = []
+    for pretrained_model_name_or_path in pretrained_model_name_or_paths:
+        if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
+            # Load textual inversion file
+            model_file = None
+
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except Exception as e:
+                    if not allow_pickle:
+                        raise e
+
+                    model_file = None
+
+            if model_file is None:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=weight_name or TEXT_INVERSION_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
+        else:
+            state_dict = pretrained_model_name_or_path
+
+        state_dicts.append(state_dict)
+
+    return state_dicts
+
+
 class TextualInversionLoaderMixin:
     r"""
     Load textual inversion tokens and embeddings to the tokenizer and text encoder.
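As an illustration of the new helper's contract, a short sketch (assuming the function is importable from `diffusers.loaders`, where this diff defines it; the repo id and tensor are examples):

import torch
from diffusers.loaders import load_textual_inversion_state_dicts

# Hub repo ids are resolved and downloaded; dicts and raw tensors pass through.
sources = ["sd-concepts-library/cat-toy", {"<my-token>": torch.randn(768)}]
state_dicts = load_textual_inversion_state_dicts(sources)
assert len(state_dicts) == len(sources)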
@@ -685,6 +760,97 @@ class TextualInversionLoaderMixin:
         return prompt

+    def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
+        if tokenizer is None:
+            raise ValueError(
+                f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
+                f" `{self.load_textual_inversion.__name__}`"
+            )
+
+        if text_encoder is None:
+            raise ValueError(
+                f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
+                f" `{self.load_textual_inversion.__name__}`"
+            )
+
+        if len(pretrained_model_name_or_paths) != len(tokens):
+            raise ValueError(
+                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and a list of tokens of length {len(tokens)}. "
+                "Make sure both lists have the same length."
+            )
+
+        valid_tokens = [t for t in tokens if t is not None]
+        if len(set(valid_tokens)) < len(valid_tokens):
+            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+
+    @staticmethod
+    def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
+        all_tokens = []
+        all_embeddings = []
+        for state_dict, token in zip(state_dicts, tokens):
+            if isinstance(state_dict, torch.Tensor):
+                if token is None:
+                    raise ValueError(
+                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    )
+                loaded_token = token
+                embedding = state_dict
+            elif len(state_dict) == 1:
+                # diffusers
+                loaded_token, embedding = next(iter(state_dict.items()))
+            elif "string_to_param" in state_dict:
+                # A1111
+                loaded_token = state_dict["name"]
+                embedding = state_dict["string_to_param"]["*"]
+            else:
+                raise ValueError(
+                    f"Loaded state dictionary is incorrect: {state_dict}. \n\n"
+                    "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
+                    " input key."
+                )
+
+            if token is not None and loaded_token != token:
+                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+            else:
+                token = loaded_token
+
+            if token in tokenizer.get_vocab():
+                raise ValueError(
+                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+                )
+
+            all_tokens.append(token)
+            all_embeddings.append(embedding)
+
+        return all_tokens, all_embeddings
+
+    @staticmethod
+    def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
+        all_tokens = []
+        all_embeddings = []
+
+        for embedding, token in zip(embeddings, tokens):
+            if f"{token}_1" in tokenizer.get_vocab():
+                multi_vector_tokens = [token]
+                i = 1
+                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+                    multi_vector_tokens.append(f"{token}_{i}")
+                    i += 1
+
+                raise ValueError(
+                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+                )
+
+            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+            if is_multi_vector:
+                all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+                all_embeddings += [e for e in embedding]  # noqa: C416
+            else:
+                all_tokens += [token]
+                all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+
+        return all_tokens, all_embeddings
+
     def load_textual_inversion(
         self,
         pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
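To make the two helpers above concrete, a sketch of the three state-dict formats `_retrieve_tokens_and_embeddings` accepts and of the multi-vector expansion `_extend_tokens_and_embeddings` performs (token names and shapes are illustrative):

import torch

# diffusers format: a single-key dict mapping the token to its vector.
diffusers_style = {"<cat-toy>": torch.randn(768)}
# A1111 format: token under "name", vectors under string_to_param["*"].
a1111_style = {"name": "<cat-toy>", "string_to_param": {"*": torch.randn(2, 768)}}
# Raw tensor: allowed only when an explicit `token=` is passed alongside.
raw_tensor = torch.randn(768)

# Multi-vector expansion: a [3, 768] embedding for "<style>" becomes three
# tokens, one 768-d row each.
embedding = torch.randn(3, 768)
tokens = ["<style>"] + [f"<style>_{i}" for i in range(1, embedding.shape[0])]
assert tokens == ["<style>", "<style>_1", "<style>_2"]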
@@ -790,25 +956,44 @@ class TextualInversionLoaderMixin:
         ```

         """
-        if tokenizer is None:
-            raise ValueError(
-                f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
-                f" `{self.load_textual_inversion.__name__}`"
-            )
-
-        if text_encoder is None:
-            raise ValueError(
-                f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
-                f" `{self.load_textual_inversion.__name__}`"
-            )
+        # 1. Set correct tokenizer and text encoder
+        tokenizer = tokenizer or getattr(self, "tokenizer", None)
+        text_encoder = text_encoder or getattr(self, "text_encoder", None)
+
+        # 2. Normalize inputs
+        pretrained_model_name_or_paths = (
+            [pretrained_model_name_or_path]
+            if not isinstance(pretrained_model_name_or_path, list)
+            else pretrained_model_name_or_path
+        )
+        tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token
+
+        # 3. Check inputs
+        self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
+
+        # 4. Load state dicts of textual embeddings
+        state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
+
+        # 5. Retrieve tokens and embeddings
+        tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
+
+        # 6. Extend tokens and embeddings for multi vector
+        tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)
+
+        # 7. Make sure all embeddings have the correct size
+        expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
+        if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
+            raise ValueError(
+                "Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
+                f"to be of shape {expected_emb_dim}, but found {[emb.shape[-1] for emb in embeddings]}."
+            )
+
+        # 8. Now we can be sure that loading the embedding matrix works
+        # < Unsafe code:

-        # Remove any existing hooks.
+        # 8.1 Offload all hooks in case the pipeline was cpu offloaded before; make sure we offload and onload again
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
-        recursive = False
         for _, component in self.components.items():
             if isinstance(component, nn.Module):
                 if hasattr(component, "_hf_hook"):
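The hook bookkeeping around this point relies on accelerate. A condensed, self-contained sketch of the detect-and-remove pattern (`detach_offload_hooks` is a hypothetical helper name, not part of the diff):

from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
from torch import nn

def detach_offload_hooks(pipe):
    # Record which offload flavor was active so the caller can re-enable it later.
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for component in pipe.components.values():
        if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
            is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
            # Sequential offload also hooks submodules, hence the recursive removal.
            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
    return is_model_cpu_offload, is_sequential_cpu_offload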
@@ -817,168 +1002,34 @@ class TextualInversionLoaderMixin:
                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
                     )
-                    recursive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recursive)
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        revision = kwargs.pop("revision", None)
-        subfolder = kwargs.pop("subfolder", None)
-        weight_name = kwargs.pop("weight_name", None)
-        use_safetensors = kwargs.pop("use_safetensors", None)
+        # 8.2 save expected device and dtype
+        device = text_encoder.device
+        dtype = text_encoder.dtype

-        allow_pickle = False
-        if use_safetensors is None:
-            use_safetensors = True
-            allow_pickle = True
-
-        user_agent = {
-            "file_type": "text_inversion",
-            "framework": "pytorch",
-        }
-
-        if not isinstance(pretrained_model_name_or_path, list):
-            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
-        else:
-            pretrained_model_name_or_paths = pretrained_model_name_or_path
-
-        if isinstance(token, str):
-            tokens = [token]
-        elif token is None:
-            tokens = [None] * len(pretrained_model_name_or_paths)
-        else:
-            tokens = token
-
-        if len(pretrained_model_name_or_paths) != len(tokens):
-            raise ValueError(
-                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and a list of tokens of length {len(tokens)}. "
-                "Make sure both lists have the same length."
-            )
-
-        valid_tokens = [t for t in tokens if t is not None]
-        if len(set(valid_tokens)) < len(valid_tokens):
-            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
-
-        token_ids_and_embeddings = []
-
-        for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
-                # 1. Load textual inversion file
-                model_file = None
-                # Let's first try to load .safetensors weights
-                if (use_safetensors and weight_name is None) or (
-                    weight_name is not None and weight_name.endswith(".safetensors")
-                ):
-                    try:
-                        model_file = _get_model_file(
-                            pretrained_model_name_or_path,
-                            weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
-                            cache_dir=cache_dir,
-                            force_download=force_download,
-                            resume_download=resume_download,
-                            proxies=proxies,
-                            local_files_only=local_files_only,
-                            use_auth_token=use_auth_token,
-                            revision=revision,
-                            subfolder=subfolder,
-                            user_agent=user_agent,
-                        )
-                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                    except Exception as e:
-                        if not allow_pickle:
-                            raise e
-
-                        model_file = None
-
-                if model_file is None:
-                    model_file = _get_model_file(
-                        pretrained_model_name_or_path,
-                        weights_name=weight_name or TEXT_INVERSION_NAME,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        use_auth_token=use_auth_token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                    )
-                    state_dict = torch.load(model_file, map_location="cpu")
-            else:
-                state_dict = pretrained_model_name_or_path
-
-            # 2. Load token and embedding correctly from file
-            loaded_token = None
-            if isinstance(state_dict, torch.Tensor):
-                if token is None:
-                    raise ValueError(
-                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
-                    )
-                embedding = state_dict
-            elif len(state_dict) == 1:
-                # diffusers
-                loaded_token, embedding = next(iter(state_dict.items()))
-            elif "string_to_param" in state_dict:
-                # A1111
-                loaded_token = state_dict["name"]
-                embedding = state_dict["string_to_param"]["*"]
-
-            if token is not None and loaded_token != token:
-                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
-            else:
-                token = loaded_token
-
-            embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device)
-
-            # 3. Make sure we don't mess up the tokenizer or text encoder
-            vocab = tokenizer.get_vocab()
-            if token in vocab:
-                raise ValueError(
-                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
-                )
-            elif f"{token}_1" in vocab:
-                multi_vector_tokens = [token]
-                i = 1
-                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
-                    multi_vector_tokens.append(f"{token}_{i}")
-                    i += 1
-
-                raise ValueError(
-                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
-                )
-
-            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
-
-            if is_multi_vector:
-                tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-                embeddings = [e for e in embedding]  # noqa: C416
-            else:
-                tokens = [token]
-                embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+        # 8.3 Increase token embedding matrix
+        text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
+        input_embeddings = text_encoder.get_input_embeddings().weight

+        # 8.4 Load token and embedding
+        for token, embedding in zip(tokens, embeddings):
             # add tokens and get ids
-            tokenizer.add_tokens(tokens)
-            token_ids = tokenizer.convert_tokens_to_ids(tokens)
-            token_ids_and_embeddings += zip(token_ids, embeddings)
-
+            tokenizer.add_tokens(token)
+            token_id = tokenizer.convert_tokens_to_ids(token)
+            input_embeddings.data[token_id] = embedding
             logger.info(f"Loaded textual inversion embedding for {token}.")

-        # resize token embeddings and set all new embeddings
-        text_encoder.resize_token_embeddings(len(tokenizer))
-        for token_id, embedding in token_ids_and_embeddings:
-            text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+        input_embeddings.to(dtype=dtype, device=device)

-        # offload back
+        # 8.5 Offload the model again
         if is_model_cpu_offload:
             self.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             self.enable_sequential_cpu_offload()

         # / Unsafe Code >


 class LoraLoaderMixin:
     r"""
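Stripped of the offload handling, the "unsafe" section reduces to growing the embedding matrix and writing one row per token. A standalone sketch of that pattern (`install_embeddings` is a hypothetical name; tokenizer and text_encoder are assumed to be a matching transformers pair):

def install_embeddings(tokenizer, text_encoder, tokens, embeddings):
    # Grow the embedding matrix once, then copy each vector into its new row.
    text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
    input_embeddings = text_encoder.get_input_embeddings().weight
    for token, embedding in zip(tokens, embeddings):
        tokenizer.add_tokens(token)
        token_id = tokenizer.convert_tokens_to_ids(token)
        input_embeddings.data[token_id] = embedding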
@@ -2598,7 +2649,6 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
 
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
-        recursive = False
         for _, component in self.components.items():
             if isinstance(component, torch.nn.Module):
                 if hasattr(component, "_hf_hook"):
@@ -2607,8 +2657,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
                     logger.info(
                         "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                     )
-                    recursive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recursive)
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

         state_dict, network_alphas = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
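The same hook-handling simplification is mirrored for SDXL LoRA loading; user-facing behavior is unchanged, e.g. (checkpoint id is real, the LoRA location is an illustrative placeholder):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# Hooks from model/sequential CPU offload are removed before loading and
# re-applied afterwards, exactly as for load_textual_inversion().
pipe.load_lora_weights("path/to/sdxl-lora")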