1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix bug in Textual Inversion Unloading (#9304)

* Update textual_inversion.py

* add unload test

* add comment

* fix style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
bonlime
2024-10-19 17:37:32 +05:00
committed by GitHub
parent 2541d141d5
commit 5d3e7bdaaa
2 changed files with 23 additions and 0 deletions

View File

@@ -561,6 +561,8 @@ class TextualInversionLoaderMixin:
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
key_id += 1
tokenizer._update_trie()
# set correct total vocab size after removing tokens
tokenizer._update_total_vocab_size()
# Delete from text encoder
text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim

View File

@@ -947,6 +947,27 @@ class DownloadTests(unittest.TestCase):
emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item()
)
def test_textual_inversion_unload(self):
pipe1 = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
)
pipe1 = pipe1.to(torch_device)
orig_tokenizer_size = len(pipe1.tokenizer)
orig_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)
token = "<*>"
ten = torch.ones((32,))
pipe1.load_textual_inversion(ten, token=token)
pipe1.unload_textual_inversion()
pipe1.load_textual_inversion(ten, token=token)
pipe1.unload_textual_inversion()
final_tokenizer_size = len(pipe1.tokenizer)
final_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)
# both should be restored to original size
assert final_tokenizer_size == orig_tokenizer_size
assert final_emb_size == orig_emb_size
def test_download_ignore_files(self):
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
with tempfile.TemporaryDirectory() as tmpdirname: