mirror of https://github.com/huggingface/diffusers.git
fast tok update (#13036)
* v5 tok update

* ruff

* keep pre v5 slow code path

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
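For context, the affected API is TextualInversionLoaderMixin.unload_textual_inversion, which must strip a loaded concept token back out of the pipeline's tokenizer. A minimal usage sketch of that code path (the model id and token name below are placeholder examples, not part of this change):

# Sketch of the code path this commit touches; model id and token are examples.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
pipe.load_textual_inversion("sd-concepts-library/cat-toy", token="<cat-toy>")
# Unloading has to delete "<cat-toy>" from the tokenizer's added tokens; with
# tokenizers v5 fast tokenizers this now goes through JSON serialization.
pipe.unload_textual_inversion("<cat-toy>")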
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 from typing import Dict, List, Optional, Union
 
 import safetensors
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
+from tokenizers import Tokenizer as TokenizerFast
 from torch import nn
 
 from ..models.modeling_utils import load_state_dict
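The two new imports serve the fast-tokenizer branch in the next hunk: json is used to edit the tokenizer's serialized state, and tokenizers.Tokenizer (imported as TokenizerFast) to rebuild it. A minimal sketch of that round trip, assuming only the tokenizers library (the empty BPE model is a stand-in, not what diffusers uses):

# Sketch: round-trip a tokenizers.Tokenizer through its JSON form, as the new
# code path does via the _tokenizer attribute of a transformers fast tokenizer.
import json

from tokenizers import Tokenizer as TokenizerFast
from tokenizers.models import BPE

tok = TokenizerFast(BPE())                         # stand-in tokenizer
payload = json.loads(tok.to_str())                 # serialized state as a dict
payload["added_tokens"] = []                       # edit the state in place
tok = TokenizerFast.from_str(json.dumps(payload))  # rebuild from edited state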
@@ -547,23 +549,39 @@ class TextualInversionLoaderMixin:
             else:
                 last_special_token_id = added_token_id
 
-        # Delete from tokenizer
-        for token_id, token_to_remove in zip(token_ids, tokens):
-            del tokenizer._added_tokens_decoder[token_id]
-            del tokenizer._added_tokens_encoder[token_to_remove]
-
-        # Make all token ids sequential in tokenizer
-        key_id = 1
-        for token_id in tokenizer.added_tokens_decoder:
-            if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
-                token = tokenizer._added_tokens_decoder[token_id]
-                tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
-                del tokenizer._added_tokens_decoder[token_id]
-                tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
-                key_id += 1
-        tokenizer._update_trie()
-        # set correct total vocab size after removing tokens
-        tokenizer._update_total_vocab_size()
+        # Fast tokenizers (v5+)
+        if hasattr(tokenizer, "_tokenizer"):
+            # Fast tokenizers: serialize, filter tokens, reload
+            tokenizer_json = json.loads(tokenizer._tokenizer.to_str())
+            new_id = last_special_token_id + 1
+            filtered = []
+            for tok in tokenizer_json.get("added_tokens", []):
+                if tok.get("content") in set(tokens):
+                    continue
+                if not tok.get("special", False):
+                    tok["id"] = new_id
+                    new_id += 1
+                filtered.append(tok)
+            tokenizer_json["added_tokens"] = filtered
+            tokenizer._tokenizer = TokenizerFast.from_str(json.dumps(tokenizer_json))
+        else:
+            # Slow tokenizers
+            for token_id, token_to_remove in zip(token_ids, tokens):
+                del tokenizer._added_tokens_decoder[token_id]
+                del tokenizer._added_tokens_encoder[token_to_remove]
+
+            key_id = 1
+            for token_id in list(tokenizer.added_tokens_decoder.keys()):
+                if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
+                    token = tokenizer._added_tokens_decoder[token_id]
+                    tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
+                    del tokenizer._added_tokens_decoder[token_id]
+                    tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
+                    key_id += 1
+            if hasattr(tokenizer, "_update_trie"):
+                tokenizer._update_trie()
+            if hasattr(tokenizer, "_update_total_vocab_size"):
+                tokenizer._update_total_vocab_size()
 
         # Delete from text encoder
         text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
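The renumbering invariant of the fast path can be checked in isolation: special tokens keep their ids, removed tokens vanish, and the surviving non-special added tokens are re-packed contiguously above last_special_token_id. A toy run of the same loop (all ids and token strings invented for illustration):

# Toy data illustrating the filter-and-renumber loop above; values are invented.
tokens = {"<cat-toy>"}                  # tokens being unloaded
last_special_token_id = 49407
added_tokens = [
    {"id": 49406, "content": "<|startoftext|>", "special": True},
    {"id": 49407, "content": "<|endoftext|>", "special": True},
    {"id": 49408, "content": "<cat-toy>", "special": False},
    {"id": 49409, "content": "<other-concept>", "special": False},
]

new_id = last_special_token_id + 1
filtered = []
for tok in added_tokens:
    if tok["content"] in tokens:
        continue                        # drop the unloaded token entirely
    if not tok["special"]:
        tok["id"] = new_id              # compact ids so no gaps remain
        new_id += 1
    filtered.append(tok)

assert [t["id"] for t in filtered] == [49406, 49407, 49408]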