mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fast tok update (#13036)

* v5 tok update

* ruff

* keep pre v5 slow code path

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Author: Ita Zaporozhets
Date: 2026-01-28 12:43:04 +01:00 (committed by GitHub)
Parent: ef913010d4
Commit: 2ac39ba664
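
For orientation before the diff: a minimal usage sketch of the code path this commit touches. The checkpoint and concept repo below are illustrative placeholders (any pipeline that mixes in TextualInversionLoaderMixin would do); only load_textual_inversion and unload_textual_inversion are the real entry points.

import torch
from diffusers import StableDiffusionPipeline

# Hypothetical checkpoint and concept repo, used only to show the call sequence.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_textual_inversion("sd-concepts-library/cat-toy", token="<cat-toy>")

# unload_textual_inversion() is the method changed below: with a fast (Rust-backed)
# tokenizer it now takes a serialize/filter/reload branch, otherwise it falls back
# to the kept pre-v5 slow code path.
pipe.unload_textual_inversion("<cat-toy>")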


@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from typing import Dict, List, Optional, Union

 import safetensors
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
+from tokenizers import Tokenizer as TokenizerFast
 from torch import nn

 from ..models.modeling_utils import load_state_dict
@@ -547,23 +549,39 @@ class TextualInversionLoaderMixin:
             else:
                 last_special_token_id = added_token_id

-        # Delete from tokenizer
-        for token_id, token_to_remove in zip(token_ids, tokens):
-            del tokenizer._added_tokens_decoder[token_id]
-            del tokenizer._added_tokens_encoder[token_to_remove]
-
-        # Make all token ids sequential in tokenizer
-        key_id = 1
-        for token_id in tokenizer.added_tokens_decoder:
-            if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
-                token = tokenizer._added_tokens_decoder[token_id]
-                tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
-                del tokenizer._added_tokens_decoder[token_id]
-                tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
-                key_id += 1
-        tokenizer._update_trie()
-        # set correct total vocab size after removing tokens
-        tokenizer._update_total_vocab_size()
+        # Fast tokenizers (v5+)
+        if hasattr(tokenizer, "_tokenizer"):
+            # Fast tokenizers: serialize, filter tokens, reload
+            tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
+            new_id = last_special_token_id + 1
+            filtered = []
+            for tok in tokenizer_json.get("added_tokens", []):
+                if tok.get("content") in set(tokens):
+                    continue
+                if not tok.get("special", False):
+                    tok["id"] = new_id
+                    new_id += 1
+                filtered.append(tok)
+            tokenizer_json["added_tokens"] = filtered
+            tokenizer._tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
+        else:
+            # Slow tokenizers
+            for token_id, token_to_remove in zip(token_ids, tokens):
+                del tokenizer._added_tokens_decoder[token_id]
+                del tokenizer._added_tokens_encoder[token_to_remove]
+            # Make all token ids sequential in tokenizer
+            key_id = 1
+            for token_id in list(tokenizer.added_tokens_decoder.keys()):
+                if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
+                    token = tokenizer._added_tokens_decoder[token_id]
+                    tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
+                    del tokenizer._added_tokens_decoder[token_id]
+                    tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
+                    key_id += 1
+            if hasattr(tokenizer, "_update_trie"):
+                tokenizer._update_trie()
+            if hasattr(tokenizer, "_update_total_vocab_size"):
+                tokenizer._update_total_vocab_size()

         # Delete from text encoder
         text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
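
The fast-tokenizer branch above boils down to a serialize, filter added_tokens, reload round trip on the underlying tokenizers.Tokenizer (imported in this diff as TokenizerFast). Below is a self-contained sketch of that idea; the WordLevel toy vocab and the <concept-*> tokens are made up for illustration and this is not diffusers code.

import json

from tokenizers import Tokenizer
from tokenizers.models import WordLevel

# Toy tokenizer: a tiny base vocab plus two learned-concept tokens.
tok = Tokenizer(WordLevel({"[UNK]": 0, "a": 1, "photo": 2, "of": 3}, unk_token="[UNK]"))
tok.add_tokens(["<concept-1>", "<concept-2>"])

# Serialize to JSON, drop the unwanted added token, renumber the remaining
# non-special added tokens so their ids stay sequential, then rebuild.
to_remove = {"<concept-1>"}
state = json.loads(tok.to_str())
next_id = tok.get_vocab_size(with_added_tokens=False)  # first id after the base vocab
kept = []
for entry in state["added_tokens"]:
    if entry["content"] in to_remove:
        continue
    if not entry["special"]:
        entry["id"] = next_id
        next_id += 1
    kept.append(entry)
state["added_tokens"] = kept
tok = Tokenizer.from_str(json.dumps(state))

assert "<concept-1>" not in tok.get_vocab()
assert "<concept-2>" in tok.get_vocab()

Rebuilding from the serialized state avoids reaching into slow-tokenizer internals such as _added_tokens_decoder, which is why the diff keeps those attribute manipulations only on the pre-v5 slow path.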