mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix bug in Textual Inversion Unloading (#9304)
* Update textual_inversion.py * add unload test * add comment * fix style --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
This commit is contained in:
@@ -561,6 +561,8 @@ class TextualInversionLoaderMixin:
|
||||
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
|
||||
key_id += 1
|
||||
tokenizer._update_trie()
|
||||
# set correct total vocab size after removing tokens
|
||||
tokenizer._update_total_vocab_size()
|
||||
|
||||
# Delete from text encoder
|
||||
text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
|
||||
|
||||
@@ -947,6 +947,27 @@ class DownloadTests(unittest.TestCase):
|
||||
emb1[num_tokens + 1].sum().item() == emb2[num_tokens + 1].sum().item() == emb3[num_tokens + 1].sum().item()
|
||||
)
|
||||
|
||||
def test_textual_inversion_unload(self):
|
||||
pipe1 = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
|
||||
)
|
||||
pipe1 = pipe1.to(torch_device)
|
||||
orig_tokenizer_size = len(pipe1.tokenizer)
|
||||
orig_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)
|
||||
|
||||
token = "<*>"
|
||||
ten = torch.ones((32,))
|
||||
pipe1.load_textual_inversion(ten, token=token)
|
||||
pipe1.unload_textual_inversion()
|
||||
pipe1.load_textual_inversion(ten, token=token)
|
||||
pipe1.unload_textual_inversion()
|
||||
|
||||
final_tokenizer_size = len(pipe1.tokenizer)
|
||||
final_emb_size = len(pipe1.text_encoder.get_input_embeddings().weight)
|
||||
# both should be restored to original size
|
||||
assert final_tokenizer_size == orig_tokenizer_size
|
||||
assert final_emb_size == orig_emb_size
|
||||
|
||||
def test_download_ignore_files(self):
|
||||
# Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
|
||||
Reference in New Issue
Block a user