import os
import math
import time
import typing
from collections import OrderedDict
import torch
from compel.embeddings_provider import BaseTextualInversionManager, EmbeddingsProvider
from transformers import PreTrainedTokenizer
from modules import shared, prompt_parser, devices, sd_models
from modules.prompt_parser_xhinker import get_weighted_text_embeddings_sd15, get_weighted_text_embeddings_sdxl_2p, get_weighted_text_embeddings_sd3, get_weighted_text_embeddings_flux1, get_weighted_text_embeddings_chroma

debug_enabled = os.environ.get('SD_PROMPT_DEBUG', None)
debug = shared.log.trace if debug_enabled else lambda *args, **kwargs: None
debug('Trace: PROMPT')
orig_encode_token_ids_to_embeddings = EmbeddingsProvider._encode_token_ids_to_embeddings # pylint: disable=protected-access
token_dict = None # used by helper get_tokens
token_type = None # used by helper get_tokens
cache = OrderedDict()
last_attention = None
embedder = None


def prompt_compatible(pipe=None):
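    """Return True when the loaded pipeline class is supported by this prompt parser; otherwise log a warning and return False."""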
    pipe = pipe or shared.sd_model
    if (
        'StableDiffusion' not in pipe.__class__.__name__ and
        'DemoFusion' not in pipe.__class__.__name__ and
        'StableCascade' not in pipe.__class__.__name__ and
        'Flux' not in pipe.__class__.__name__ and
        'Chroma' not in pipe.__class__.__name__ and
        'HiDreamImage' not in pipe.__class__.__name__
    ):
        shared.log.warning(f"Prompt parser not supported: {pipe.__class__.__name__}")
        return False
    return True


def prepare_model(pipe=None):
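    """Resolve the pipeline that owns the text encoder(s).

    Falls back to shared.sd_model, unwraps a nested `.pipe` wrapper when needed,
    releases any offload hooks and returns None when no text encoder is found.
    """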
    pipe = pipe or shared.sd_model
    if not hasattr(pipe, "text_encoder") and hasattr(shared.sd_model, "pipe"):
        pipe = pipe.pipe
    if not hasattr(pipe, "text_encoder"):
        return None
    elif hasattr(pipe, "maybe_free_model_hooks"):
        pipe.maybe_free_model_hooks()
        devices.torch_gc()
    return pipe


class PromptEmbedder:
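    """Encodes batched positive/negative prompts into per-step embeddings.

    All embeds are stored as nested lists indexed as [batch][schedule step] and are
    retrieved by calling the instance with the attribute name and step.

    Minimal usage sketch (assuming `p` is the sdnext processing object carrying
    `sd_model` and `network_data`, as used elsewhere in this module):

        embedder = PromptEmbedder(prompts, negative_prompts, steps, clip_skip, p)
        prompt_embeds = embedder('prompt_embeds', step=0)
        negative_embeds = embedder('negative_prompt_embeds', step=0)
    """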
    def __init__(self, prompts, negative_prompts, steps, clip_skip, p):
        t0 = time.time()
        self.prompts = prompts
        self.negative_prompts = negative_prompts
        self.batchsize = len(self.prompts)
        self.attention = last_attention
        self.allsame = False # don't collapse prompts
        # self.allsame = self.compare_prompts() # collapses batched prompts to single prompt if possible
        self.steps = steps
        self.clip_skip = clip_skip
        # All embeds are nested lists: outer list is batch length, inner list is schedule length
        self.prompt_embeds = [[] for _ in range(self.batchsize)]
        self.positive_pooleds = [[] for _ in range(self.batchsize)]
        self.negative_prompt_embeds = [[] for _ in range(self.batchsize)]
        self.negative_pooleds = [[] for _ in range(self.batchsize)]
        self.prompt_attention_masks = [[] for _ in range(self.batchsize)]
        self.negative_prompt_attention_masks = [[] for _ in range(self.batchsize)]
        self.positive_schedule = None
        self.negative_schedule = None
        self.scheduled_prompt = False
        if hasattr(p, 'dummy'):
            return
        earlyout = self.checkcache(p)
        if earlyout:
            return
        self.pipe = prepare_model(p.sd_model)
        if self.pipe is None:
            shared.log.error("Prompt encode: cannot find text encoder in model")
            return
        seen_prompts = {}
        # per prompt in batch
        for batchidx, (prompt, negative_prompt) in enumerate(zip(self.prompts, self.negative_prompts)):
            self.prepare_schedule(prompt, negative_prompt)
            schedule_key = (
                tuple(self.positive_schedule) if self.positive_schedule is not None else None,
                tuple(self.negative_schedule) if self.negative_schedule is not None else None,
                self.scheduled_prompt,
            )
            cache_key = (prompt, negative_prompt, schedule_key)
            cached_idx = seen_prompts.get(cache_key)
            if cached_idx is not None:
                self.clone_embeds(batchidx, cached_idx)
                continue
            if self.scheduled_prompt:
                self.scheduled_encode(self.pipe, batchidx)
            else:
                self.encode(self.pipe, prompt, negative_prompt, batchidx)
            seen_prompts[cache_key] = batchidx
        self.checkcache(p)
        debug(f"Prompt encode: time={(time.time() - t0):.3f}")

    def checkcache(self, p) -> bool:
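        """Look up or store the current embeddings in the module-level cache.

        Returns True when embeddings were found in the cache (and copied onto self)
        or were just stored; returns False when caching is disabled, the prompt is
        scheduled, the attention parser changed, or nothing has been encoded yet.
        """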
        if shared.opts.sd_textencoder_cache_size == 0:
            return False
        if self.scheduled_prompt:
            debug("Prompt cache: scheduled prompt")
            cache.clear()
            return False
        if self.attention != shared.opts.prompt_attention:
            debug(f"Prompt cache: parser={shared.opts.prompt_attention} changed")
            cache.clear()
            return False

        def flatten(xss):
            return [x for xs in xss for x in xs]

        # unpack EN data in case of TE LoRA
        en_data = p.network_data
        en_data = [idx.items for item in en_data.values() for idx in item]
        effective_batch = 1 if self.allsame else self.batchsize
        key = str([self.prompts, self.negative_prompts, effective_batch, self.clip_skip, self.steps, en_data])
        item = cache.get(key)
        if not item:
            if not any(flatten(emb) for emb in [self.prompt_embeds,
                                                self.negative_prompt_embeds,
                                                self.positive_pooleds,
                                                self.negative_pooleds,
                                                self.prompt_attention_masks,
                                                self.negative_prompt_attention_masks]):
                return False
            else:
                cache[key] = {'prompt_embeds': self.prompt_embeds,
                              'negative_prompt_embeds': self.negative_prompt_embeds,
                              'positive_pooleds': self.positive_pooleds,
                              'negative_pooleds': self.negative_pooleds,
                              'prompt_attention_masks': self.prompt_attention_masks,
                              'negative_prompt_attention_masks': self.negative_prompt_attention_masks,
                              }
                debug(f"Prompt cache: add={key}")
                while len(cache) > int(shared.opts.sd_textencoder_cache_size):
                    cache.popitem(last=False)
                return True
        if item:
            self.__dict__.update(cache[key])
            cache.move_to_end(key)
            if self.allsame and len(self.prompt_embeds) < self.batchsize:
                self.prompt_embeds = [self.prompt_embeds[0]] * self.batchsize
                self.positive_pooleds = [self.positive_pooleds[0]] * self.batchsize
                self.negative_prompt_embeds = [self.negative_prompt_embeds[0]] * self.batchsize
                self.negative_pooleds = [self.negative_pooleds[0]] * self.batchsize
                self.prompt_attention_masks = [self.prompt_attention_masks[0]] * self.batchsize
                self.negative_prompt_attention_masks = [self.negative_prompt_attention_masks[0]] * self.batchsize
            debug(f"Prompt cache: get={key}")
            return True

    def compare_prompts(self):
        same = (self.prompts == [self.prompts[0]] * len(self.prompts) and self.negative_prompts == [self.negative_prompts[0]] * len(self.negative_prompts))
        if same:
            self.prompts = [self.prompts[0]]
            self.negative_prompts = [self.negative_prompts[0]]
        return same

    def prepare_schedule(self, prompt, negative_prompt):
        self.positive_schedule, scheduled = get_prompt_schedule(prompt, self.steps)
        self.negative_schedule, neg_scheduled = get_prompt_schedule(negative_prompt, self.steps)
        self.scheduled_prompt = scheduled or neg_scheduled
        debug(f"Prompt schedule: positive={self.positive_schedule} negative={self.negative_schedule} scheduled={scheduled}")

    def scheduled_encode(self, pipe, batchidx):
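        """Encode every scheduled sub-prompt for one batch entry, reusing embeddings via extend_embeds when a positive+negative pair repeats across steps."""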
        prompt_dict = {} # index cache
        for i in range(max(len(self.positive_schedule), len(self.negative_schedule))):
            positive_prompt = self.positive_schedule[i % len(self.positive_schedule)]
            negative_prompt = self.negative_schedule[i % len(self.negative_schedule)]
            # skip repeated scheduled subprompts
            idx = prompt_dict.get(positive_prompt+negative_prompt)
            if idx is not None:
                self.extend_embeds(batchidx, idx)
                continue
            self.encode(pipe, positive_prompt, negative_prompt, batchidx)
            prompt_dict[positive_prompt+negative_prompt] = i

    def extend_embeds(self, batchidx, idx): # Extends scheduled prompt via index
        if len(self.prompt_embeds[batchidx]) > 0:
            self.prompt_embeds[batchidx].append(self.prompt_embeds[batchidx][idx])
        if len(self.negative_prompt_embeds[batchidx]) > 0:
            self.negative_prompt_embeds[batchidx].append(self.negative_prompt_embeds[batchidx][idx])
        if len(self.positive_pooleds[batchidx]) > 0:
            self.positive_pooleds[batchidx].append(self.positive_pooleds[batchidx][idx])
        if len(self.negative_pooleds[batchidx]) > 0:
            self.negative_pooleds[batchidx].append(self.negative_pooleds[batchidx][idx])
        if len(self.prompt_attention_masks[batchidx]) > 0:
            self.prompt_attention_masks[batchidx].append(self.prompt_attention_masks[batchidx][idx])
        if len(self.negative_prompt_attention_masks[batchidx]) > 0:
            self.negative_prompt_attention_masks[batchidx].append(self.negative_prompt_attention_masks[batchidx][idx])

    def encode(self, pipe, positive_prompt, negative_prompt, batchidx):
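        """Encode a single positive/negative prompt pair with the selected attention parser ('xhinker' or the default compel path) and store the result for the given batch index."""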
        if positive_prompt is None:
            positive_prompt = ''
        if negative_prompt is None:
            negative_prompt = ''
        global last_attention # pylint: disable=global-statement
        self.attention = shared.opts.prompt_attention
        last_attention = self.attention
        if self.attention == "xhinker":
            (
                prompt_embed,
                positive_pooled,
                prompt_attention_mask,
                negative_embed,
                negative_pooled,
                negative_prompt_attention_mask
            ) = get_xhinker_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip)
        else:
            (
                prompt_embed,
                positive_pooled,
                prompt_attention_mask,
                negative_embed,
                negative_pooled,
                negative_prompt_attention_mask
            ) = get_weighted_text_embeddings(pipe, positive_prompt, negative_prompt, self.clip_skip)

        def _store(target, value):
            if value is None:
                return
            # scheduled prompts need to keep all slices, unscheduled can overwrite
            if self.scheduled_prompt and len(target[batchidx]) > 0:
                target[batchidx].append(value)
            else:
                target[batchidx] = [value]

        _store(self.prompt_embeds, prompt_embed)
        _store(self.negative_prompt_embeds, negative_embed)
        _store(self.positive_pooleds, positive_pooled)
        _store(self.negative_pooleds, negative_pooled)
        _store(self.prompt_attention_masks, prompt_attention_mask)
        _store(self.negative_prompt_attention_masks, negative_prompt_attention_mask)
        if debug_enabled:
            get_tokens(pipe, 'positive', positive_prompt)
            get_tokens(pipe, 'negative', negative_prompt)

    def clone_embeds(self, batchidx, idx):
        def _clone(target):
            if len(target) <= idx:
                return
            src = target[idx]
            if isinstance(src, list):
                target[batchidx] = [item if not isinstance(item, list) else list(item) for item in src]
            else:
                target[batchidx] = src

        _clone(self.prompt_embeds)
        _clone(self.negative_prompt_embeds)
        _clone(self.positive_pooleds)
        _clone(self.negative_pooleds)
        _clone(self.prompt_attention_masks)
        _clone(self.negative_prompt_attention_masks)

    def __call__(self, key, step=0):
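        """Return the requested embeds (e.g. 'prompt_embeds', 'positive_pooleds') for a scheduler step, concatenated across the batch.

        Returns None when the key holds no data, for example pooled embeds on SD1.5
        or negative embeds on Flux; HiDream entries are stored as [t5, llama] pairs
        and returned as a two-tensor list.
        """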
        batch = getattr(self, key)
        res = []
        try:
            if len(batch) == 0 or len(batch[0]) == 0:
                return None # flux has no negative prompts
            if isinstance(batch[0][0], list) and len(batch[0][0]) == 2 and isinstance(batch[0][0][1], torch.Tensor) and batch[0][0][1].shape[0] == 32:
                # hidream uses a list of t5 + llama prompt embeds: [t5_embeds, llama_embeds]
                # t5_embeds shape: [batch_size, seq_len, dim]
                # llama_embeds shape: [number_of_hidden_states, batch_size, seq_len, dim]
                res2 = []
                for i in range(self.batchsize):
                    if len(batch[i]) == 0: # if asking for a null key, ie pooled on SD1.5
                        return None
                    try:
                        res.append(batch[i][step][0])
                        res2.append(batch[i][step][1])
                    except IndexError:
                        # if not scheduled, return default
                        res.append(batch[i][0][0])
                        res2.append(batch[i][0][1])
                res = [torch.cat(res, dim=0), torch.cat(res2, dim=1)]
                return res
            else:
                for i in range(self.batchsize):
                    if len(batch[i]) == 0: # if asking for a null key, ie pooled on SD1.5
                        return None
                    try:
                        res.append(batch[i][step])
                    except IndexError:
                        res.append(batch[i][0]) # if not scheduled, return default
                if any(res[0].shape[1] != r.shape[1] for r in res):
                    res = pad_to_same_length(self.pipe, res)
                return torch.cat(res)
        except Exception as e:
            shared.log.error(f"Prompt encode: {e}")
            return None


def compel_hijack(self, token_ids: torch.Tensor, attention_mask: typing.Optional[torch.Tensor] = None) -> torch.Tensor:
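    """Replacement for EmbeddingsProvider._encode_token_ids_to_embeddings implementing clip-skip.

    The integer part of returned_embeddings_type selects the hidden state (counted
    from the end), the fractional part interpolates between two adjacent hidden
    states, and a positive value applies the final layer norm.
    """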
    needs_hidden_states = self.returned_embeddings_type != 1
    text_encoder_output = self.text_encoder(token_ids, attention_mask, output_hidden_states=needs_hidden_states, return_dict=True)

    if not needs_hidden_states:
        return text_encoder_output.last_hidden_state
    try:
        normalized = self.returned_embeddings_type > 0
        clip_skip = math.floor(abs(self.returned_embeddings_type))
        interpolation = abs(self.returned_embeddings_type) - clip_skip
    except Exception:
        normalized = False
        clip_skip = 1
        interpolation = False
    if interpolation:
        hidden_state = (1 - interpolation) * text_encoder_output.hidden_states[-clip_skip] + interpolation * text_encoder_output.hidden_states[-(clip_skip+1)]
    else:
        hidden_state = text_encoder_output.hidden_states[-clip_skip]
    if normalized:
        hidden_state = self.text_encoder.text_model.final_layer_norm(hidden_state)
    return hidden_state


def sd3_compel_hijack(self, token_ids: torch.Tensor, attention_mask: typing.Optional[torch.Tensor] = None) -> torch.Tensor:
    needs_hidden_states = True
    text_encoder_output = self.text_encoder(token_ids, attention_mask, output_hidden_states=needs_hidden_states, return_dict=True)
    clip_skip = int(self.returned_embeddings_type)
    hidden_state = text_encoder_output.hidden_states[-(clip_skip+1)]
    return hidden_state


def insert_parser_highjack(pipename):
    if "StableDiffusion3" in pipename:
        EmbeddingsProvider._encode_token_ids_to_embeddings = sd3_compel_hijack # pylint: disable=protected-access
        debug("Load SD3 Parser hijack")
    else:
        EmbeddingsProvider._encode_token_ids_to_embeddings = compel_hijack # pylint: disable=protected-access
        debug("Load Standard Parser hijack")


insert_parser_highjack("Initialize")


# from https://github.com/damian0815/compel/blob/main/src/compel/diffusers_textual_inversion_manager.py
class DiffusersTextualInversionManager(BaseTextualInversionManager):
    def __init__(self, pipe, tokenizer):
        self.pipe = pipe
        self.tokenizer = tokenizer
        if hasattr(self.pipe, 'embedding_db'):
            self.pipe.embedding_db.embeddings_used.clear()

    # code from
    # https://github.com/huggingface/diffusers/blob/705c592ea98ba4e288d837b9cba2767623c78603/src/diffusers/loaders.py
    def maybe_convert_prompt(self, prompt: typing.Union[str, typing.List[str]], tokenizer: PreTrainedTokenizer):
        prompts = [prompt] if not isinstance(prompt, typing.List) else prompt
        prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]
        if not isinstance(prompt, typing.List):
            return prompts[0]
        return prompts

    def _maybe_convert_prompt(self, prompt: str, tokenizer: PreTrainedTokenizer):
        tokens = tokenizer.tokenize(prompt)
        unique_tokens = set(tokens)
        for token in unique_tokens:
            if token in tokenizer.added_tokens_encoder:
                if hasattr(self.pipe, 'embedding_db'):
                    self.pipe.embedding_db.embeddings_used.append(token)
                replacement = token
                i = 1
                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                    replacement += f" {token}_{i}"
                    i += 1
                prompt = prompt.replace(token, replacement)
        if hasattr(self.pipe, 'embedding_db'):
            self.pipe.embedding_db.embeddings_used = list(set(self.pipe.embedding_db.embeddings_used))
        debug(f'Prompt: convert="{prompt}"')
        return prompt

    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: typing.List[int]) -> typing.List[int]:
        if len(token_ids) == 0:
            return token_ids
        prompt = self.pipe.tokenizer.decode(token_ids)
        prompt = self.maybe_convert_prompt(prompt, self.pipe.tokenizer)
        debug(f'Prompt: expand="{prompt}"')
        return self.pipe.tokenizer.encode(prompt, add_special_tokens=False)


def get_prompt_schedule(prompt, steps):
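    """Expand A1111-style prompt scheduling into a per-step list of sub-prompts.

    Returns (schedule, scheduled) where `scheduled` indicates the prompt changes over
    steps. Illustrative example: with 10 steps, '[cat:dog:0.5]' yields a schedule that
    uses 'cat' for the first half of the steps and 'dog' for the second half.
    """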
    temp = []
    schedule = prompt_parser.get_learned_conditioning_prompt_schedules([prompt], steps)[0]
    if all(x == schedule[0] for x in schedule):
        return [schedule[0][1]], False
    for chunk in schedule:
        for s in range(steps):
            if len(temp) < s + 1 <= chunk[0]:
                temp.append(chunk[1])
    return temp, len(schedule) > 1


def get_tokens(pipe, msg, prompt):
    global token_dict, token_type # pylint: disable=global-statement
    token_count = 0
    if shared.sd_loaded and hasattr(pipe, 'tokenizer') and pipe.tokenizer is not None:
        tokenizer = pipe.tokenizer
        # For multi-modal processors (e.g., PixtralProcessor), use the underlying text tokenizer
        if hasattr(tokenizer, 'tokenizer') and tokenizer.tokenizer is not None:
            tokenizer = tokenizer.tokenizer
        prompt = prompt.replace(' BOS ', ' !!!!!!!! ').replace(' EOS ', ' !!!!!!! ')
        debug(f'Prompt tokenizer: type={msg} prompt="{prompt}"')
        if token_dict is None or token_type != shared.sd_model_type:
            token_type = shared.sd_model_type
            fn = getattr(tokenizer, 'name_or_path', '')
            if fn.endswith('tokenizer'):
                fn = os.path.join(fn, 'vocab.json')
            else:
                fn = os.path.join(fn, 'tokenizer', 'vocab.json')
            token_dict = shared.readfile(fn, silent=True)
            added_tokens = getattr(tokenizer, 'added_tokens_decoder', {})
            for k, v in added_tokens.items():
                token_dict[str(v)] = k
            shared.log.debug(f'Tokenizer: words={len(token_dict)} file="{fn}"')
        has_bos_token = getattr(tokenizer, 'bos_token_id', None) is not None
        has_eos_token = getattr(tokenizer, 'eos_token_id', None) is not None
        try:
            ids = tokenizer(prompt)
            ids = getattr(ids, 'input_ids', [])
        except Exception:
            ids = []
        if has_bos_token and has_eos_token:
            for i in range(len(ids)):
                if ids[i] == 21622:
                    ids[i] = tokenizer.bos_token_id
                elif ids[i] == 15203:
                    ids[i] = tokenizer.eos_token_id
        tokens = []
        for i in ids:
            try:
                key = list(token_dict.keys())[list(token_dict.values()).index(i)]
                tokens.append(key)
            except Exception:
                tokens.append(f'UNK_{i}')
        token_count = len(ids) - int(has_bos_token) - int(has_eos_token)
        debug(f'Prompt tokenizer: type={msg} tokens={token_count} tokens={tokens} ids={ids}')
    return token_count


def normalize_prompt(pairs: list):
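    """Rescale per-section attention weights so the token-weighted mean weight is 1.0, leaving control tokens (weight -1) untouched.

    For example, two equally long sections weighted 1.2 and 0.8 already average to 1.0
    and stay unchanged, while a uniformly up-weighted prompt is pulled back to neutral.
    """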
    num_words = 0
    total_weight = 0
    for section in pairs:
        words = len(section[0].split())
        if section[1] == -1: # control tokens
            continue
        num_words += words
        total_weight += section[1] * words
    avg_weight = round(100 * total_weight / num_words) / 100 if num_words > 0 else 1
    debug(f'Prompt stats: words={num_words} weight={avg_weight}')
    for section in pairs:
        section[1] = section[1] / avg_weight if section[1] != -1 else -1 # skip control tokens
    debug(f'Prompt normalized: {pairs}')
    return pairs


def get_prompts_with_weights(pipe, prompt: str):
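    """Expand textual-inversion tokens, parse attention syntax into (text, weight) sections and return them as parallel tuples, optionally normalizing the weights."""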
    t0 = time.time()
    manager = DiffusersTextualInversionManager(pipe, pipe.tokenizer or pipe.tokenizer_2)
    prompt = manager.maybe_convert_prompt(prompt, pipe.tokenizer or pipe.tokenizer_2)
    texts_and_weights = prompt_parser.parse_prompt_attention(prompt)
    if shared.opts.prompt_mean_norm:
        texts_and_weights = normalize_prompt(texts_and_weights)
    texts, text_weights = zip(*texts_and_weights)
    avg_weight = 0
    min_weight = 1
    max_weight = 0
    sections = 0

    try:
        all_tokens = 0
        for text, weight in zip(texts, text_weights):
            tokens = get_tokens(pipe, 'section', text)
            all_tokens += tokens
            avg_weight += tokens*weight
            min_weight = min(min_weight, weight)
            max_weight = max(max_weight, weight)
            if text != 'BREAK':
                sections += 1
        if all_tokens > 0:
            avg_weight = avg_weight / all_tokens
        debug(f'Prompt tokenizer: parser={shared.opts.prompt_attention} len={len(prompt)} sections={sections} tokens={all_tokens} weights={min_weight:.2f}/{avg_weight:.2f}/{max_weight:.2f}')
    except Exception:
        pass
    debug(f'Prompt: weights={texts_and_weights} time={(time.time() - t0):.3f}')

    return texts, text_weights


def prepare_embedding_providers(pipe, clip_skip) -> list[EmbeddingsProvider]:
    device = devices.device
    embeddings_providers = []
    if 'StableCascade' in pipe.__class__.__name__:
        embedding_type = -(clip_skip)
    elif 'XL' in pipe.__class__.__name__:
        embedding_type = -(clip_skip + 1)
    else:
        embedding_type = clip_skip
    embedding_args = {
        'truncate': False,
        'returned_embeddings_type': embedding_type,
        'device': device,
        'dtype_for_device_getter': lambda device: devices.dtype,
    }
    if getattr(pipe, "prior_pipe", None) is not None and getattr(pipe.prior_pipe, "tokenizer", None) is not None and getattr(pipe.prior_pipe, "text_encoder", None) is not None:
        provider = EmbeddingsProvider(padding_attention_mask_value=0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, **embedding_args)
        embeddings_providers.append(provider)
        no_mask_provider = EmbeddingsProvider(padding_attention_mask_value=1 if "sote" in pipe.sd_checkpoint_info.name.lower() else 0, tokenizer=pipe.prior_pipe.tokenizer, text_encoder=pipe.prior_pipe.text_encoder, **embedding_args)
        embeddings_providers.append(no_mask_provider)
    elif getattr(pipe, "tokenizer", None) is not None and getattr(pipe, "text_encoder", None) is not None:
        if pipe.text_encoder.__class__.__name__.startswith('CLIP'):
            sd_models.move_model(pipe.text_encoder, devices.device, force=True)
        provider = EmbeddingsProvider(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder, **embedding_args)
        embeddings_providers.append(provider)
    if getattr(pipe, "tokenizer_2", None) is not None and getattr(pipe, "text_encoder_2", None) is not None:
        if pipe.text_encoder_2.__class__.__name__.startswith('CLIP'):
            sd_models.move_model(pipe.text_encoder_2, devices.device, force=True)
        provider = EmbeddingsProvider(tokenizer=pipe.tokenizer_2, text_encoder=pipe.text_encoder_2, **embedding_args)
        embeddings_providers.append(provider)
    if getattr(pipe, "tokenizer_3", None) is not None and getattr(pipe, "text_encoder_3", None) is not None:
        if pipe.text_encoder_3.__class__.__name__.startswith('CLIP'):
            sd_models.move_model(pipe.text_encoder_3, devices.device, force=True)
        provider = EmbeddingsProvider(tokenizer=pipe.tokenizer_3, text_encoder=pipe.text_encoder_3, **embedding_args)
        embeddings_providers.append(provider)
    return embeddings_providers


def pad_to_same_length(pipe, embeds, empty_embedding_providers=None):
    if not hasattr(pipe, 'encode_prompt') and ('StableCascade' not in pipe.__class__.__name__):
        return embeds
    device = devices.device
    if shared.opts.diffusers_zeros_prompt_pad or 'StableDiffusion3' in pipe.__class__.__name__:
        empty_embed = [torch.zeros((1, 77, embeds[0].shape[2]), device=device, dtype=embeds[0].dtype)]
    else:
        try:
            if 'StableCascade' in pipe.__class__.__name__:
                empty_embed = empty_embedding_providers[0].get_embeddings_for_weighted_prompt_fragments(text_batch=[[""]], fragment_weights_batch=[[1]], should_return_tokens=False, device=device)
                empty_embed = [empty_embed]
            else:
                empty_embed = pipe.encode_prompt("")
        except TypeError: # SD1.5
            empty_embed = pipe.encode_prompt("", device, 1, False)
    max_token_count = max([embed.shape[1] for embed in embeds])
    repeats = max_token_count - min([embed.shape[1] for embed in embeds])
    empty_batched = empty_embed[0].to(embeds[0].device).repeat(embeds[0].shape[0], repeats // empty_embed[0].shape[1], 1)
    for i, embed in enumerate(embeds):
        if embed.shape[1] < max_token_count:
            embed = torch.cat([embed, empty_batched], dim=1)
            embeds[i] = embed
    return embeds


def split_prompts(pipe, prompt, SD3=False):
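    """Split a prompt into up to four per-text-encoder prompts on the 'TE2:', 'TE3:' and 'TE4:' markers; encoders without an explicit section fall back to the first prompt.

    For example 'a cat TE2: a photo of a cat' sends 'a cat' to the first encoder and
    'a photo of a cat' to the second.
    """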
    if prompt.find("TE2:") != -1:
        prompt, prompt2 = prompt.split("TE2:")
    else:
        prompt2 = prompt

    if prompt.find("TE3:") != -1:
        prompt, prompt3 = prompt.split("TE3:")
    elif prompt2.find("TE3:") != -1:
        prompt2, prompt3 = prompt2.split("TE3:")
    else:
        prompt3 = prompt

    if prompt.find("TE4:") != -1:
        prompt, prompt4 = prompt.split("TE4:")
    elif prompt2.find("TE4:") != -1:
        prompt2, prompt4 = prompt2.split("TE4:")
    elif prompt3.find("TE4:") != -1:
        prompt3, prompt4 = prompt3.split("TE4:")
    else:
        prompt4 = prompt

    prompt = prompt.strip()
    prompt2 = " " if prompt2.strip() == "" else prompt2.strip()
    prompt3 = " " if prompt3.strip() == "" else prompt3.strip()
    prompt4 = " " if prompt4.strip() == "" else prompt4.strip()

    if SD3 and prompt3 != " ":
        ps, _ws = get_prompts_with_weights(pipe, prompt3)
        prompt3 = " ".join(ps)
    return prompt, prompt2, prompt3, prompt4


def get_weighted_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None):
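    """Compel-based weighted prompt encoding.

    Flux, Chroma and HiDream pipelines delegate to pipe.encode_prompt; other pipelines
    encode each 'BREAK'-separated section per text encoder, derive pooled embeds where
    the encoder provides them, pad positive and negative embeds to the same length and
    append T5 embeds for SD3. Returns (prompt_embeds, pooled, attention_mask,
    negative_embeds, negative_pooled, negative_attention_mask).
    """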
    device = devices.device
    if prompt is None:
        prompt = ''
    if neg_prompt is None:
        neg_prompt = ''
    SD3 = bool(hasattr(pipe, 'text_encoder_3') and not hasattr(pipe, 'text_encoder_4'))
    prompt, prompt_2, prompt_3, prompt_4 = split_prompts(pipe, prompt, SD3)
    neg_prompt, neg_prompt_2, neg_prompt_3, neg_prompt_4 = split_prompts(pipe, neg_prompt, SD3)

    if "Flux" in pipe.__class__.__name__: # clip is only used for the pooled embeds
        prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(prompt=prompt, prompt_2=prompt_2, device=device, num_images_per_prompt=1)
        return prompt_embeds, pooled_prompt_embeds, None, None, None, None # no negative support

    if "Chroma" in pipe.__class__.__name__: # does not use clip and has no pooled embeds
        prompt_embeds, _, prompt_attention_mask, negative_prompt_embeds, _, negative_prompt_attention_mask = pipe.encode_prompt(prompt=prompt, negative_prompt=neg_prompt, device=device, num_images_per_prompt=1)
        return prompt_embeds, None, prompt_attention_mask, negative_prompt_embeds, None, negative_prompt_attention_mask

    if "HiDreamImage" in pipe.__class__.__name__: # clip is only used for the pooled embeds
        prompt_embeds_t5, negative_prompt_embeds_t5, prompt_embeds_llama3, negative_prompt_embeds_llama3, pooled_prompt_embeds, negative_pooled_prompt_embeds = pipe.encode_prompt(
            prompt=prompt, prompt_2=prompt_2, prompt_3=prompt_3, prompt_4=prompt_4,
            negative_prompt=neg_prompt, negative_prompt_2=neg_prompt_2, negative_prompt_3=neg_prompt_3, negative_prompt_4=neg_prompt_4,
            device=device, num_images_per_prompt=1,
        )
        prompt_embeds = [prompt_embeds_t5, prompt_embeds_llama3]
        negative_prompt_embeds = [negative_prompt_embeds_t5, negative_prompt_embeds_llama3]
        return prompt_embeds, pooled_prompt_embeds, None, negative_prompt_embeds, negative_pooled_prompt_embeds, None

    if prompt != prompt_2:
        ps = [get_prompts_with_weights(pipe, p) for p in [prompt, prompt_2]]
        ns = [get_prompts_with_weights(pipe, p) for p in [neg_prompt, neg_prompt_2]]
    else:
        ps = 2 * [get_prompts_with_weights(pipe, prompt)]
        ns = 2 * [get_prompts_with_weights(pipe, neg_prompt)]

    positives, positive_weights = zip(*ps)
    negatives, negative_weights = zip(*ns)
    if hasattr(pipe, "tokenizer_2") and not hasattr(pipe, "tokenizer"):
        # zip() yields tuples, so drop the first entry by slicing instead of pop()
        positives, positive_weights = positives[1:], positive_weights[1:]
        negatives, negative_weights = negatives[1:], negative_weights[1:]

    embedding_providers = prepare_embedding_providers(pipe, clip_skip)
    if len(embedding_providers) == 0:
        shared.log.error("Prompt encode: cannot find text encoder in model")
        return None, None, None, None, None, None
    empty_embedding_providers = None
    if 'StableCascade' in pipe.__class__.__name__:
        empty_embedding_providers = [embedding_providers[1]]
        embedding_providers = [embedding_providers[0]]

    prompt_embeds = []
    negative_prompt_embeds = []
    pooled_prompt_embeds = []
    negative_pooled_prompt_embeds = []
    for i in range(len(embedding_providers)):
        if i >= len(positives): # te may be missing/unloaded
            break
        t0 = time.time()
        text = list(positives[i])
        weights = list(positive_weights[i])
        text.append('BREAK')
        weights.append(-1)
        provider_embed = []
        ptokens = 0
        while 'BREAK' in text:
            pos = text.index('BREAK')
            debug(f'Prompt: section="{text[:pos]}" len={len(text[:pos])} weights={weights[:pos]}')
            if len(text[:pos]) > 0:
                embed, ptokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[text[:pos]], fragment_weights_batch=[weights[:pos]], device=device, should_return_tokens=True)
                provider_embed.append(embed)
            text = text[pos + 1:]
            weights = weights[pos + 1:]
        prompt_embeds.append(torch.cat(provider_embed, dim=1))
        # negative prompt has no keywords
        if shared.opts.diffusers_zeros_prompt_pad and len(negatives[i]) == 1 and negatives[i][0] in {"", " "}:
            embed, ntokens = torch.zeros_like(embed), torch.zeros_like(ptokens)
        else:
            embed, ntokens = embedding_providers[i].get_embeddings_for_weighted_prompt_fragments(text_batch=[negatives[i]], fragment_weights_batch=[negative_weights[i]], device=device, should_return_tokens=True)
        negative_prompt_embeds.append(embed)
        debug(f'Prompt: unpadded={prompt_embeds[0].shape} TE{i+1} ptokens={torch.count_nonzero(ptokens)} ntokens={torch.count_nonzero(ntokens)} time={(time.time() - t0):.3f}')
    if SD3:
        t0 = time.time()
        pooled_prompt_embeds.append(embedding_providers[0].get_pooled_embeddings(texts=positives[0] if len(positives[0]) == 1 else [" ".join(positives[0])], device=device))
        pooled_prompt_embeds.append(embedding_providers[1].get_pooled_embeddings(texts=positives[-1] if len(positives[-1]) == 1 else [" ".join(positives[-1])], device=device))
        negative_pooled_prompt_embeds.append(embedding_providers[0].get_pooled_embeddings(texts=negatives[0] if len(negatives[0]) == 1 else [" ".join(negatives[0])], device=device))
        negative_pooled_prompt_embeds.append(embedding_providers[1].get_pooled_embeddings(texts=negatives[-1] if len(negatives[-1]) == 1 else [" ".join(negatives[-1])], device=device))
        pooled_prompt_embeds = torch.cat(pooled_prompt_embeds, dim=-1)
        negative_pooled_prompt_embeds = torch.cat(negative_pooled_prompt_embeds, dim=-1)
        debug(f'Prompt: pooled={pooled_prompt_embeds[0].shape} time={(time.time() - t0):.3f}')
    elif prompt_embeds[-1].shape[-1] > 768:
        t0 = time.time()
        if shared.opts.te_pooled_embeds:
            pooled_prompt_embeds = embedding_providers[-1].text_encoder.text_projection(prompt_embeds[-1][
                torch.arange(prompt_embeds[-1].shape[0], device=device),
                (ptokens.to(dtype=torch.int, device=device) == 49407)
                .int()
                .argmax(dim=-1),
            ])
            negative_pooled_prompt_embeds = embedding_providers[-1].text_encoder.text_projection(negative_prompt_embeds[-1][
                torch.arange(negative_prompt_embeds[-1].shape[0], device=device),
                (ntokens.to(dtype=torch.int, device=device) == 49407)
                .int()
                .argmax(dim=-1),
            ])
        else:
            try:
                pooled_prompt_embeds = embedding_providers[-1].get_pooled_embeddings(texts=[prompt_2], device=device) if prompt_embeds[-1].shape[-1] > 768 else None
                if shared.opts.diffusers_zeros_prompt_pad and neg_prompt_2 in {"", " "}:
                    negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) if negative_prompt_embeds[-1].shape[-1] > 768 else None
                else:
                    negative_pooled_prompt_embeds = embedding_providers[-1].get_pooled_embeddings(texts=[neg_prompt_2], device=device) if negative_prompt_embeds[-1].shape[-1] > 768 else None
            except Exception:
                pooled_prompt_embeds = None
                negative_pooled_prompt_embeds = None
        debug(f'Prompt: pooled shape={pooled_prompt_embeds[0].shape if pooled_prompt_embeds is not None else None} time={(time.time() - t0):.3f}')

    prompt_embeds = torch.cat(prompt_embeds, dim=-1) if len(prompt_embeds) > 1 else prompt_embeds[0]
    negative_prompt_embeds = torch.cat(negative_prompt_embeds, dim=-1) if len(negative_prompt_embeds) > 1 else negative_prompt_embeds[0]
    if pooled_prompt_embeds == []:
        pooled_prompt_embeds = None
    if negative_pooled_prompt_embeds == []:
        negative_pooled_prompt_embeds = None
    debug(f'Prompt: positive={prompt_embeds.shape if prompt_embeds is not None else None} pooled={pooled_prompt_embeds.shape if pooled_prompt_embeds is not None else None} negative={negative_prompt_embeds.shape if negative_prompt_embeds is not None else None} pooled={negative_pooled_prompt_embeds.shape if negative_pooled_prompt_embeds is not None else None}')
    if prompt_embeds.shape[1] != negative_prompt_embeds.shape[1]:
        [prompt_embeds, negative_prompt_embeds] = pad_to_same_length(pipe, [prompt_embeds, negative_prompt_embeds], empty_embedding_providers=empty_embedding_providers)
    if SD3:
        device = devices.device
        t5_prompt_embed = pipe._get_t5_prompt_embeds( # pylint: disable=protected-access
            prompt=prompt_3,
            num_images_per_prompt=prompt_embeds.shape[0],
            device=device,
        )
        prompt_embeds = torch.nn.functional.pad(
            prompt_embeds, (0, t5_prompt_embed.shape[-1] - prompt_embeds.shape[-1])
        ).to(device)
        prompt_embeds = torch.cat([prompt_embeds, t5_prompt_embed], dim=-2)
        t5_negative_prompt_embed = pipe._get_t5_prompt_embeds( # pylint: disable=protected-access
            prompt=neg_prompt_3,
            num_images_per_prompt=prompt_embeds.shape[0],
            device=device,
        )
        negative_prompt_embeds = torch.nn.functional.pad(
            negative_prompt_embeds, (0, t5_negative_prompt_embed.shape[-1] - negative_prompt_embeds.shape[-1])
        ).to(device)
        negative_prompt_embeds = torch.cat([negative_prompt_embeds, t5_negative_prompt_embed], dim=-2)
    return prompt_embeds, pooled_prompt_embeds, None, negative_prompt_embeds, negative_pooled_prompt_embeds, None


def get_xhinker_text_embeddings(pipe, prompt: str = "", neg_prompt: str = "", clip_skip: int = None):
    is_sd3 = hasattr(pipe, 'text_encoder_3')
    prompt, prompt_2, _prompt_3, _ = split_prompts(pipe, prompt, is_sd3)
    neg_prompt, neg_prompt_2, _neg_prompt_3, _ = split_prompts(pipe, neg_prompt, is_sd3)
    try:
        prompt = pipe.maybe_convert_prompt(prompt, pipe.tokenizer)
        neg_prompt = pipe.maybe_convert_prompt(neg_prompt, pipe.tokenizer)
        prompt_2 = pipe.maybe_convert_prompt(prompt_2, pipe.tokenizer_2)
        neg_prompt_2 = pipe.maybe_convert_prompt(neg_prompt_2, pipe.tokenizer_2)
    except Exception:
        pass
    prompt_embed = positive_pooled = negative_embed = negative_pooled = prompt_attention_mask = negative_prompt_attention_mask = None

    te1_device, te2_device, te3_device = None, None, None
    if hasattr(pipe, "text_encoder") and pipe.text_encoder.device != devices.device:
        te1_device = pipe.text_encoder.device
        sd_models.move_model(pipe.text_encoder, devices.device, force=True)
    if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2.device != devices.device:
        te2_device = pipe.text_encoder_2.device
        sd_models.move_model(pipe.text_encoder_2, devices.device, force=True)
    if hasattr(pipe, "text_encoder_3") and pipe.text_encoder_3.device != devices.device:
        te3_device = pipe.text_encoder_3.device
        sd_models.move_model(pipe.text_encoder_3, devices.device, force=True)

    if 'StableDiffusion3' in pipe.__class__.__name__:
        prompt_embed, negative_embed, positive_pooled, negative_pooled = get_weighted_text_embeddings_sd3(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, use_t5_encoder=bool(pipe.text_encoder_3))
    elif 'Flux' in pipe.__class__.__name__:
        prompt_embed, positive_pooled = get_weighted_text_embeddings_flux1(pipe=pipe, prompt=prompt, prompt2=prompt_2, device=devices.device)
    elif 'Chroma' in pipe.__class__.__name__:
        prompt_embed, prompt_attention_mask, negative_embed, negative_prompt_attention_mask = get_weighted_text_embeddings_chroma(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, device=devices.device)
    elif 'XL' in pipe.__class__.__name__:
        prompt_embed, negative_embed, positive_pooled, negative_pooled = get_weighted_text_embeddings_sdxl_2p(pipe=pipe, prompt=prompt, prompt_2=prompt_2, neg_prompt=neg_prompt, neg_prompt_2=neg_prompt_2)
    else:
        prompt_embed, negative_embed = get_weighted_text_embeddings_sd15(pipe=pipe, prompt=prompt, neg_prompt=neg_prompt, clip_skip=clip_skip)

    # restore each text encoder to the device it was offloaded from
    if te1_device is not None:
        sd_models.move_model(pipe.text_encoder, te1_device, force=True)
    if te2_device is not None:
        sd_models.move_model(pipe.text_encoder_2, te2_device, force=True)
    if te3_device is not None:
        sd_models.move_model(pipe.text_encoder_3, te3_device, force=True)

    return prompt_embed, positive_pooled, prompt_attention_mask, negative_embed, negative_pooled, negative_prompt_attention_mask