mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Feat] Enable State Dict For Textual Inversion Loader (#3439)
* enable state dict for textual inversion loader
* Empty-Commit | restart CI
* Empty-Commit | restart CI
* Empty-Commit | restart CI
* Empty-Commit | restart CI
* add tests
* fix tests
* fix tests
* fix tests

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -470,7 +470,7 @@ class TextualInversionLoaderMixin:
|
||||
|
||||
def load_textual_inversion(
|
||||
self,
|
||||
pretrained_model_name_or_path: Union[str, List[str]],
|
||||
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
|
||||
token: Optional[Union[str, List[str]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -485,7 +485,7 @@ class TextualInversionLoaderMixin:
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]`):
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
@@ -494,6 +494,8 @@ class TextualInversionLoaderMixin:
|
||||
- A path to a *directory* containing textual inversion weights, e.g.
|
||||
`./my_text_inversion_directory/`.
|
||||
- A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`.
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
Or a list of those elements.
|
||||
token (`str` or `List[str]`, *optional*):
|
||||
@@ -618,7 +620,7 @@ class TextualInversionLoaderMixin:
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if isinstance(pretrained_model_name_or_path, str):
|
||||
if not isinstance(pretrained_model_name_or_path, list):
|
||||
pretrained_model_name_or_paths = [pretrained_model_name_or_path]
|
||||
else:
|
||||
pretrained_model_name_or_paths = pretrained_model_name_or_path
|
||||
@@ -643,16 +645,38 @@ class TextualInversionLoaderMixin:
|
||||
token_ids_and_embeddings = []
|
||||
|
||||
for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
|
||||
# 1. Load textual inversion file
|
||||
model_file = None
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
if not isinstance(pretrained_model_name_or_path, dict):
|
||||
# 1. Load textual inversion file
|
||||
model_file = None
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except Exception as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
|
||||
model_file = None
|
||||
|
||||
if model_file is None:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
|
||||
weights_name=weight_name or TEXT_INVERSION_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
@@ -663,28 +687,9 @@ class TextualInversionLoaderMixin:
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except Exception as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
|
||||
model_file = None
|
||||
|
||||
if model_file is None:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=weight_name or TEXT_INVERSION_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path
|
||||
|
||||
# 2. Load token and embedding correctly from file
|
||||
loaded_token = None
|
||||
|
||||
@@ -663,6 +663,65 @@ class DownloadTests(unittest.TestCase):
|
||||
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
|
||||
assert out.shape == (1, 128, 128, 3)
|
||||
|
||||
# single token state dict load
|
||||
ten = {"<x>": torch.ones((32,))}
|
||||
pipe.load_textual_inversion(ten)
|
||||
|
||||
token = pipe.tokenizer.convert_tokens_to_ids("<x>")
|
||||
assert token == num_tokens + 10, "Added token must be at spot `num_tokens`"
|
||||
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32
|
||||
assert pipe._maybe_convert_prompt("<x>", pipe.tokenizer) == "<x>"
|
||||
|
||||
prompt = "hey <x>"
|
||||
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
|
||||
assert out.shape == (1, 128, 128, 3)
|
||||
|
||||
# multi embedding state dict load
|
||||
ten1 = {"<xxxxx>": torch.ones((32,))}
|
||||
ten2 = {"<xxxxxx>": 2 * torch.ones((1, 32))}
|
||||
|
||||
pipe.load_textual_inversion([ten1, ten2])
|
||||
|
||||
token = pipe.tokenizer.convert_tokens_to_ids("<xxxxx>")
|
||||
assert token == num_tokens + 11, "Added token must be at spot `num_tokens`"
|
||||
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32
|
||||
assert pipe._maybe_convert_prompt("<xxxxx>", pipe.tokenizer) == "<xxxxx>"
|
||||
|
||||
token = pipe.tokenizer.convert_tokens_to_ids("<xxxxxx>")
|
||||
assert token == num_tokens + 12, "Added token must be at spot `num_tokens`"
|
||||
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64
|
||||
assert pipe._maybe_convert_prompt("<xxxxxx>", pipe.tokenizer) == "<xxxxxx>"
|
||||
|
||||
prompt = "hey <xxxxx> <xxxxxx>"
|
||||
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
|
||||
assert out.shape == (1, 128, 128, 3)
|
||||
|
||||
# auto1111 multi-token state dict load
|
||||
ten = {
|
||||
"string_to_param": {
|
||||
"*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))])
|
||||
},
|
||||
"name": "<xxxx>",
|
||||
}
|
||||
|
||||
pipe.load_textual_inversion(ten)
|
||||
|
||||
token = pipe.tokenizer.convert_tokens_to_ids("<xxxx>")
|
||||
token_1 = pipe.tokenizer.convert_tokens_to_ids("<xxxx>_1")
|
||||
token_2 = pipe.tokenizer.convert_tokens_to_ids("<xxxx>_2")
|
||||
|
||||
assert token == num_tokens + 13, "Added token must be at spot `num_tokens`"
|
||||
assert token_1 == num_tokens + 14, "Added token must be at spot `num_tokens`"
|
||||
assert token_2 == num_tokens + 15, "Added token must be at spot `num_tokens`"
|
||||
assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96
|
||||
assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128
|
||||
assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160
|
||||
assert pipe._maybe_convert_prompt("<xxxx>", pipe.tokenizer) == "<xxxx> <xxxx>_1 <xxxx>_2"
|
||||
|
||||
prompt = "hey <xxxx>"
|
||||
out = pipe(prompt, num_inference_steps=1, output_type="numpy").images
|
||||
assert out.shape == (1, 128, 128, 3)
|
||||
|
||||
def test_download_ignore_files(self):
|
||||
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
||||
Reference in New Issue
Block a user