import os
import typing
import torch
from compel import ReturnedEmbeddingsType
from compel.embeddings_provider import BaseTextualInversionManager, EmbeddingsProvider
from modules import shared, prompt_parser, devices


debug = shared.log.info if os.environ.get('SD_PROMPT_DEBUG', None) is not None else lambda *args, **kwargs: None


CLIP_SKIP_MAPPING = {
    None: ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED,
    1: ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED,
    2: ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED,
}


# from https://github.com/damian0815/compel/blob/main/src/compel/diffusers_textual_inversion_manager.py
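# Textual inversion manager used by compel: expands multi-vector embedding tokens
# (token -> token token_1 token_2 ...) and tracks used embeddings via pipe.embedding_db.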
class DiffusersTextualInversionManager(BaseTextualInversionManager):
    def __init__(self, pipe, tokenizer):
        self.pipe = pipe
        self.tokenizer = tokenizer
        if hasattr(self.pipe, 'embedding_db'):
            self.pipe.embedding_db.embeddings_used.clear()

    # from https://github.com/huggingface/diffusers/blob/705c592ea98ba4e288d837b9cba2767623c78603/src/diffusers/loaders.py#L599
    def maybe_convert_prompt(self, prompt: typing.Union[str, typing.List[str]], tokenizer="PreTrainedTokenizer"):
        prompts = [prompt] if not isinstance(prompt, typing.List) else prompt
        prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
        if not isinstance(prompt, typing.List):
            return prompts[0]
        return prompts

    def _maybe_convert_prompt(self, prompt: str, tokenizer="PreTrainedTokenizer"):
        tokens = tokenizer.tokenize(prompt)
        unique_tokens = set(tokens)
        for token in unique_tokens:
            if token in tokenizer.added_tokens_encoder:
                if hasattr(self.pipe, 'embedding_db'):
                    self.pipe.embedding_db.embeddings_used.append(token)
                replacement = token
                i = 1
                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                    replacement += f" {token}_{i}"
                    i += 1
                prompt = prompt.replace(token, replacement)
        if hasattr(self.pipe, 'embedding_db'):
            self.pipe.embedding_db.embeddings_used = list(set(self.pipe.embedding_db.embeddings_used))
        debug(f'Prompt: convert={prompt}')
        return prompt

    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: typing.List[int]) -> typing.List[int]:
        if len(token_ids) == 0:
            return token_ids
        prompt = self.pipe.tokenizer.decode(token_ids)
        prompt = self.maybe_convert_prompt(prompt, self.pipe.tokenizer)
        debug(f'Prompt: expand={prompt}')
        return self.pipe.tokenizer.encode(prompt, add_special_tokens=False)


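# Encodes a batch of prompts and negative prompts into embeddings (plus pooled embeddings
# for SDXL) by running the weighted encoder once per prompt and concatenating on the batch dim.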
def encode_prompts(pipe, prompts: list, negative_prompts: list, clip_skip: typing.Optional[int] = None):
    if 'StableDiffusion' not in pipe.__class__.__name__ and 'DemoFusion' not in pipe.__class__.__name__:
        shared.log.warning(f"Prompt parser not supported: {pipe.__class__.__name__}")
        return None, None, None, None
    else:
        prompt_embeds = []
        positive_pooleds = []
        negative_embeds = []
        negative_pooleds = []
        for i in range(len(prompts)):
            prompt_embed, positive_pooled, negative_embed, negative_pooled = get_weighted_text_embeddings(pipe, prompts[i], negative_prompts[i], clip_skip)
            prompt_embeds.append(prompt_embed)
            positive_pooleds.append(positive_pooled)
            negative_embeds.append(negative_embed)
            negative_pooleds.append(negative_pooled)

    if prompt_embeds is not None:
        prompt_embeds = torch.cat(prompt_embeds, dim=0)
    if negative_embeds is not None:
        negative_embeds = torch.cat(negative_embeds, dim=0)
    if positive_pooleds is not None and shared.sd_model_type == "sdxl":
        positive_pooleds = torch.cat(positive_pooleds, dim=0)
    if negative_pooleds is not None and shared.sd_model_type == "sdxl":
        negative_pooleds = torch.cat(negative_pooleds, dim=0)
    return prompt_embeds, positive_pooleds, negative_embeds, negative_pooleds


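# Expands textual inversion tokens in a prompt and parses attention syntax into
# parallel lists of text fragments and their weights.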
def get_prompts_with_weights(prompt: str):
    manager = DiffusersTextualInversionManager(shared.sd_model, shared.sd_model.tokenizer or shared.sd_model.tokenizer_2)
    prompt = manager.maybe_convert_prompt(prompt, shared.sd_model.tokenizer or shared.sd_model.tokenizer_2)
    texts_and_weights = prompt_parser.parse_prompt_attention(prompt)
    texts = [t for t, w in texts_and_weights]
    text_weights = [w for t, w in texts_and_weights]
    debug(f'Prompt: weights={texts_and_weights}')
    return texts, text_weights


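# Builds one compel EmbeddingsProvider per available tokenizer/text-encoder pair,
# choosing which hidden states to return based on the pipeline type and clip_skip.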
def prepare_embedding_providers(pipe, clip_skip):
    device = pipe.device if str(pipe.device) != 'meta' else devices.device
    embeddings_providers = []
    if 'XL' in pipe.__class__.__name__:
        embedding_type = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
    else:
        if clip_skip is not None and clip_skip > 2:
            shared.log.warning(f"Prompt parser unsupported: clip_skip={clip_skip}")
            clip_skip = 2
        embedding_type = CLIP_SKIP_MAPPING[clip_skip]
    if getattr(pipe, "tokenizer", None) is not None and getattr(pipe, "text_encoder", None) is not None:
        embedding = EmbeddingsProvider(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, truncate=False, returned_embeddings_type=embedding_type, device=device)
        embeddings_providers.append(embedding)
    if getattr(pipe, "tokenizer_2", None) is not None and getattr(pipe, "text_encoder_2", None) is not None:
        embedding = EmbeddingsProvider(tokenizer=pipe.tokenizer_2, text_encoder=pipe.text_encoder_2, truncate=False, returned_embeddings_type=embedding_type, device=device)
        embeddings_providers.append(embedding)
    return embeddings_providers


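# Pads a list of prompt embeddings to the same token length by appending embeddings
# of the empty prompt, so positive and negative batches can be concatenated safely.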
def pad_to_same_length(pipe, embeds):
    device = pipe.device if str(pipe.device) != 'meta' else devices.device
    try:  # SDXL
        empty_embed = pipe.encode_prompt("")
    except Exception:  # SD1.5
        empty_embed = pipe.encode_prompt("", device, 1, False)
    empty_batched = torch.cat([empty_embed[0].to(embeds[0].device)] * embeds[0].shape[0])
    max_token_count = max([embed.shape[1] for embed in embeds])
    for i, embed in enumerate(embeds):
        while embed.shape[1] < max_token_count:
            embed = torch.cat([embed, empty_batched], dim=1)
        embeds[i] = embed
    return embeds


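# Encodes a single positive/negative prompt pair into weighted embeddings using compel,
# honoring attention weights, the BREAK keyword, and a separate "TE2:" prompt for the second encoder.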
def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: typing.Optional[int] = None):
    device = pipe.device if str(pipe.device) != 'meta' else devices.device
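    # anything after a "TE2:" separator is routed to the second text encoder only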
    prompt_2 = prompt.split("TE2:")[-1]
    neg_prompt_2 = neg_prompt.split("TE2:")[-1]
    prompt = prompt.split("TE2:")[0]
    neg_prompt = neg_prompt.split("TE2:")[0]

    ps = [get_prompts_with_weights(p) for p in [prompt, prompt_2]]
    positives = [t for t, w in ps]
    positive_weights = [w for t, w in ps]
    ns = [get_prompts_with_weights(p) for p in [neg_prompt, neg_prompt_2]]
    negatives = [t for t, w in ns]
    negative_weights = [w for t, w in ns]
    if getattr(pipe, "tokenizer_2", None) is not None and getattr(pipe, "tokenizer", None) is None:
        positives.pop(0)
        positive_weights.pop(0)
        negatives.pop(0)
        negative_weights.pop(0)

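    # one provider per available text encoder; each encodes its own copy of the prompt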
    embedding_providers = prepare_embedding_providers(pipe, clip_skip)
    prompt_embeds = []
    negative_prompt_embeds = []
    pooled_prompt_embeds = None
    negative_pooled_prompt_embeds = None
    for i in range(len(embedding_providers)):
        # add BREAK keyword that splits the prompt into multiple fragments
        text = positives[i]
        weights = positive_weights[i]
        text.append('BREAK')
        weights.append(-1)
        provider_embed = []
        while 'BREAK' in text:
            pos = text.index('BREAK')
            debug(f'Prompt: section="{text[:pos]}" len={len(text[:pos])} weights={weights[:pos]}')
            if len(text[:pos]) > 0:
                embed, ptokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[text[:pos]], fragment_weights_batch=[weights[:pos]], device=device, should_return_tokens=True)
                provider_embed.append(embed)
            text = text[pos+1:]
            weights = weights[pos+1:]
        prompt_embeds.append(torch.cat(provider_embed, dim=1))
        # negative prompt has no keywords
        embed, ntokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[negatives[i]], fragment_weights_batch=[negative_weights[i]], device=device, should_return_tokens=True)
        negative_prompt_embeds.append(embed)

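    # pooled embeddings are only produced for the larger SDXL text encoder (hidden size > 768)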
    if prompt_embeds[-1].shape[-1] > 768:
        if shared.opts.diffusers_pooled == "weighted":
            pooled_prompt_embeds = prompt_embeds[-1][
                torch.arange(prompt_embeds[-1].shape[0], device=device),
                (ptokens.to(dtype=torch.int, device=device) == 49407)
                .int()
                .argmax(dim=-1),
            ]
            negative_pooled_prompt_embeds = negative_prompt_embeds[-1][
                torch.arange(negative_prompt_embeds[-1].shape[0], device=device),
                (ntokens.to(dtype=torch.int, device=device) == 49407)
                .int()
                .argmax(dim=-1),
            ]
        else:
            pooled_prompt_embeds = embedding_providers[-1].get_pooled_embeddings(texts=[prompt_2], device=device) if prompt_embeds[-1].shape[-1] > 768 else None
            negative_pooled_prompt_embeds = embedding_providers[-1].get_pooled_embeddings(texts=[neg_prompt_2], device=device) if negative_prompt_embeds[-1].shape[-1] > 768 else None

    prompt_embeds = torch.cat(prompt_embeds, dim=-1) if len(prompt_embeds) > 1 else prompt_embeds[0]
    negative_prompt_embeds = torch.cat(negative_prompt_embeds, dim=-1) if len(negative_prompt_embeds) > 1 else negative_prompt_embeds[0]
    debug(f'Prompt: shape={prompt_embeds.shape} negative={negative_prompt_embeds.shape}')
    if prompt_embeds.shape[1] != negative_prompt_embeds.shape[1]:
        [prompt_embeds, negative_prompt_embeds] = pad_to_same_length(pipe, [prompt_embeds, negative_prompt_embeds])
    return prompt_embeds, pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds
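
# Minimal usage sketch (an assumption, not part of this module: a StableDiffusion/SDXL
# pipeline is loaded in shared.sd_model and accepts the standard diffusers prompt_embeds arguments):
#   embeds, pooled, neg_embeds, neg_pooled = encode_prompts(shared.sd_model, ['a cat'], [''], clip_skip=None)
#   image = shared.sd_model(prompt_embeds=embeds, negative_prompt_embeds=neg_embeds, ...).images[0]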