## -----------------------------------------------------------------------------
# Generate unlimited-length prompts with weighting for SD3, SDXL, and SD1.5
# If you use sd_embed in your research, please cite the following work:
#
# ```
# @misc{sd_embed_2024,
# author = {Shudong Zhu(Andrew Zhu)},
# title = {Long Prompt Weighted Stable Diffusion Embedding},
# howpublished = {\url{https://github.com/xhinker/sd_embed}},
# year = {2024},
# }
# ```
# Author: Andrew Zhu
# Book: Using Stable Diffusion with Python, https://www.amazon.com/Using-Stable-Diffusion-Python-Generation/dp/1835086373
# Github: https://github.com/xhinker
# Medium: https://medium.com/@xhinker
## -----------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from transformers import CLIPTokenizer, T5Tokenizer
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionXLPipeline
from diffusers import StableDiffusion3Pipeline
from diffusers import FluxPipeline
from diffusers import ChromaPipeline
from modules.prompt_parser import parse_prompt_attention # use built-in A1111 parser
def get_prompts_tokens_with_weights(
clip_tokenizer: CLIPTokenizer
, prompt: str = None
):
"""
Get prompt token ids and weights; this function works for both the prompt and the negative prompt
Args:
clip_tokenizer (CLIPTokenizer)
A CLIPTokenizer
prompt (str)
A prompt string with weights
Returns:
text_tokens (list)
A list containing token ids
text_weight (list)
A list containing the corresponding weight of each token id
Example:
import torch
from diffusers_plus.tools.sd_embeddings import get_prompts_tokens_with_weights
from transformers import CLIPTokenizer
clip_tokenizer = CLIPTokenizer.from_pretrained(
"stablediffusionapi/deliberate-v2"
, subfolder = "tokenizer"
)
token_id_list, token_weight_list = get_prompts_tokens_with_weights(
clip_tokenizer = clip_tokenizer
,prompt = "a (red:1.5) cat"*70
)
"""
if (prompt is None) or (len(prompt) < 1):
prompt = "empty"
texts_and_weights = parse_prompt_attention(prompt)
text_tokens, text_weights = [], []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
token = clip_tokenizer(
word
, truncation=False # allow tokenizing a prompt of any length
).input_ids[1:-1]
# the returned token is a 1d list: [320, 1125, 539, 320]
# merge the new tokens to the all tokens holder: text_tokens
text_tokens = [*text_tokens, *token]
# each token chunk will come with one weight, like ['red cat', 2.0]
# need to expand weight for each token.
chunk_weights = [weight] * len(token)
# append the weight back to the weight holder: text_weights
text_weights = [*text_weights, *chunk_weights]
return text_tokens, text_weights
def get_prompts_tokens_with_weights_t5(
t5_tokenizer: T5Tokenizer,
prompt: str,
add_special_tokens: bool = True
):
"""
Get prompt token ids, weights, and attention masks; this function works for both the prompt and the negative prompt
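Args:
t5_tokenizer (T5Tokenizer)
prompt (str)
add_special_tokens (bool)
Returns:
text_tokens (list), text_weights (list), text_masks (list)
Example (a minimal sketch; the "google/t5-v1_1-xxl" checkpoint name is only illustrative):
from transformers import T5Tokenizer
t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
token_ids, token_weights, token_masks = get_prompts_tokens_with_weights_t5(
t5_tokenizer = t5_tokenizer
, prompt = "a (red:1.5) cat"
)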
"""
if (prompt is None) or (len(prompt) < 1):
prompt = "empty"
texts_and_weights = parse_prompt_attention(prompt)
text_tokens, text_weights, text_masks = [], [], []
for word, weight in texts_and_weights:
# tokenize and discard the starting and the ending token
inputs = t5_tokenizer(
word,
truncation=False, # allow tokenizing a prompt of any length
add_special_tokens=add_special_tokens,
return_length=False,
)
token = inputs.input_ids
mask = inputs.attention_mask
# merge the new tokens to the all tokens holder: text_tokens
text_tokens = [*text_tokens, *token]
text_masks = [*text_masks, *mask]
# each token chunk will come with one weight, like ['red cat', 2.0]
# need to expand weight for each token.
chunk_weights = [weight] * len(token)
# append the weight back to the weight holder: text_weights
text_weights = [*text_weights, *chunk_weights]
return text_tokens, text_weights, text_masks
def group_tokens_and_weights(
token_ids: list
, weights: list
, pad_last_block=False
):
"""
Produce tokens and weights in groups of 77 (bos + up to 75 tokens + eos) and pad the missing tokens
Args:
token_ids (list)
The token ids from tokenizer
weights (list)
The weights list from function get_prompts_tokens_with_weights
pad_last_block (bool)
Controls whether to fill the last token group up to 75 tokens with eos
Returns:
new_token_ids (2d list)
new_weights (2d list)
Example:
from diffusers_plus.tools.sd_embeddings import group_tokens_and_weights
token_groups,weight_groups = group_tokens_and_weights(
token_ids = token_id_list
, weights = token_weight_list
)
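# e.g. 140 input tokens yield two groups: a full [bos]+75+[eos] chunk
# and a [bos]+65+[eos] remainder (padded to 77 only when pad_last_block=True)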
"""
bos, eos = 49406, 49407 # CLIP tokenizer bos/eos token ids
# this will be a 2d list
new_token_ids = []
new_weights = []
while len(token_ids) >= 75:
# get the first 75 tokens
head_75_tokens = [token_ids.pop(0) for _ in range(75)]
head_75_weights = [weights.pop(0) for _ in range(75)]
# extract token ids and weights
temp_77_token_ids = [bos] + head_75_tokens + [eos]
temp_77_weights = [1.0] + head_75_weights + [1.0]
# add 77 token and weights chunk to the holder list
new_token_ids.append(temp_77_token_ids)
new_weights.append(temp_77_weights)
# pad whatever is left (fewer than 75 tokens)
if len(token_ids) > 0:
padding_len = 75 - len(token_ids) if pad_last_block else 0
temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
new_token_ids.append(temp_77_token_ids)
temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
new_weights.append(temp_77_weights)
return new_token_ids, new_weights
def get_weighted_text_embeddings_sd15(
pipe: StableDiffusionPipeline
, prompt: str = ""
, neg_prompt: str = ""
, pad_last_block=False
, clip_skip: int = 0
):
"""
This function can process a long prompt with weights; no length limitation
for Stable Diffusion v1.5
Args:
pipe (StableDiffusionPipeline)
prompt (str)
neg_prompt (str)
pad_last_block (bool)
clip_skip (int): number of final CLIP encoder layers to skip
Returns:
prompt_embeds (torch.Tensor)
neg_prompt_embeds (torch.Tensor)
Example:
from diffusers import StableDiffusionPipeline
text2img_pipe = StableDiffusionPipeline.from_pretrained(
"stablediffusionapi/deliberate-v2"
, torch_dtype = torch.float16
, safety_checker = None
).to("cuda:0")
prompt_embeds, neg_prompt_embeds = get_weighted_text_embeddings_sd15(
pipe = text2img_pipe
, prompt = "a (white) cat"
, neg_prompt = "blur"
)
image = text2img_pipe(
prompt_embeds = prompt_embeds
, negative_prompt_embeds = neg_prompt_embeds
, generator = torch.Generator(text2img_pipe.device).manual_seed(2)
).images[0]
"""
original_clip_layers = pipe.text_encoder.text_model.encoder.layers
if clip_skip > 0:
pipe.text_encoder.text_model.encoder.layers = original_clip_layers[:-clip_skip]
eos = pipe.tokenizer.eos_token_id
prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
pipe.tokenizer, prompt
)
neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
pipe.tokenizer, neg_prompt
)
# padding the shorter one
prompt_token_len = len(prompt_tokens)
neg_prompt_token_len = len(neg_prompt_tokens)
if prompt_token_len > neg_prompt_token_len:
# padding the neg_prompt with eos token
neg_prompt_tokens = (
neg_prompt_tokens +
[eos] * abs(prompt_token_len - neg_prompt_token_len)
)
neg_prompt_weights = (
neg_prompt_weights +
[1.0] * abs(prompt_token_len - neg_prompt_token_len)
)
else:
# padding the prompt
prompt_tokens = (
prompt_tokens
+ [eos] * abs(prompt_token_len - neg_prompt_token_len)
)
prompt_weights = (
prompt_weights
+ [1.0] * abs(prompt_token_len - neg_prompt_token_len)
)
embeds = []
neg_embeds = []
prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
prompt_tokens.copy()
, prompt_weights.copy()
, pad_last_block=pad_last_block
)
neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
neg_prompt_tokens.copy()
, neg_prompt_weights.copy()
, pad_last_block=pad_last_block
)
# getting prompt embeddings one token at a time does not work
# we must embed the prompt group by group
for i in range(len(prompt_token_groups)):
# get positive prompt embeddings with weights
token_tensor = torch.tensor(
[prompt_token_groups[i]]
, dtype=torch.long, device=pipe.text_encoder.device
)
weight_tensor = torch.tensor(
prompt_weight_groups[i]
, dtype=torch.float16
, device=pipe.text_encoder.device
)
token_embedding = pipe.text_encoder(token_tensor)[0].squeeze(0)
for j in range(len(weight_tensor)):
token_embedding[j] = token_embedding[j] * weight_tensor[j]
token_embedding = token_embedding.unsqueeze(0)
embeds.append(token_embedding)
# get negative prompt embeddings with weights
neg_token_tensor = torch.tensor(
[neg_prompt_token_groups[i]]
, dtype=torch.long, device=pipe.text_encoder.device
)
neg_weight_tensor = torch.tensor(
neg_prompt_weight_groups[i]
, dtype=torch.float16
, device=pipe.text_encoder.device
)
neg_token_embedding = pipe.text_encoder(neg_token_tensor)[0].squeeze(0)
for z in range(len(neg_weight_tensor)):
neg_token_embedding[z] = (
neg_token_embedding[z] * neg_weight_tensor[z]
)
neg_token_embedding = neg_token_embedding.unsqueeze(0)
neg_embeds.append(neg_token_embedding)
prompt_embeds = torch.cat(embeds, dim=1)
neg_prompt_embeds = torch.cat(neg_embeds, dim=1)
# recover clip layers
if clip_skip > 0:
pipe.text_encoder.text_model.encoder.layers = original_clip_layers
return prompt_embeds, neg_prompt_embeds
def get_weighted_text_embeddings_sdxl(
pipe: StableDiffusionXLPipeline
, prompt: str = ""
, neg_prompt: str = ""
, pad_last_block=True
):
"""
This function can process a long prompt with weights; no length limitation
for Stable Diffusion XL
Args:
pipe (StableDiffusionXLPipeline)
prompt (str)
neg_prompt (str)
pad_last_block (bool)
Returns:
prompt_embeds (torch.Tensor)
negative_prompt_embeds (torch.Tensor)
pooled_prompt_embeds (torch.Tensor)
negative_pooled_prompt_embeds (torch.Tensor)
Example:
from diffusers import StableDiffusionXLPipeline
text2img_pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0"
, torch_dtype = torch.float16
).to("cuda:0")
(prompt_embeds, neg_prompt_embeds
, pooled_prompt_embeds, neg_pooled_prompt_embeds) = get_weighted_text_embeddings_sdxl(
pipe = text2img_pipe
, prompt = "a (white) cat"
, neg_prompt = "blur"
)
image = text2img_pipe(
prompt_embeds = prompt_embeds
, negative_prompt_embeds = neg_prompt_embeds
, pooled_prompt_embeds = pooled_prompt_embeds
, negative_pooled_prompt_embeds = neg_pooled_prompt_embeds
, generator = torch.Generator(text2img_pipe.device).manual_seed(2)
).images[0]
"""
eos = pipe.tokenizer.eos_token_id
# tokenizer 1
prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
pipe.tokenizer, prompt
)
neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
pipe.tokenizer, neg_prompt
)
# tokenizer 2
prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(
pipe.tokenizer_2, prompt
)
neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(
pipe.tokenizer_2, neg_prompt
)
# padding the shorter one
prompt_token_len = len(prompt_tokens)
neg_prompt_token_len = len(neg_prompt_tokens)
if prompt_token_len > neg_prompt_token_len:
# padding the neg_prompt with eos token
neg_prompt_tokens = (
neg_prompt_tokens +
[eos] * abs(prompt_token_len - neg_prompt_token_len)
)
neg_prompt_weights = (
neg_prompt_weights +
[1.0] * abs(prompt_token_len - neg_prompt_token_len)
)
else:
# padding the prompt
prompt_tokens = (
prompt_tokens
+ [eos] * abs(prompt_token_len - neg_prompt_token_len)
)
prompt_weights = (
prompt_weights
+ [1.0] * abs(prompt_token_len - neg_prompt_token_len)
)
# padding the shorter one for token set 2
prompt_token_len_2 = len(prompt_tokens_2)
neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
if prompt_token_len_2 > neg_prompt_token_len_2:
# padding the neg_prompt with eos token
neg_prompt_tokens_2 = (
neg_prompt_tokens_2 +
[eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
neg_prompt_weights_2 = (
neg_prompt_weights_2 +
[1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
else:
# padding the prompt
prompt_tokens_2 = (
prompt_tokens_2
+ [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
prompt_weights_2 = (
prompt_weights_2
+ [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
embeds = []
neg_embeds = []
prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
prompt_tokens.copy()
, prompt_weights.copy()
, pad_last_block=pad_last_block
)
neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
neg_prompt_tokens.copy()
, neg_prompt_weights.copy()
, pad_last_block=pad_last_block
)
prompt_token_groups_2, _prompt_weight_groups_2 = group_tokens_and_weights(
prompt_tokens_2.copy()
, prompt_weights_2.copy()
, pad_last_block=pad_last_block
)
neg_prompt_token_groups_2, _neg_prompt_weight_groups_2 = group_tokens_and_weights(
neg_prompt_tokens_2.copy()
, neg_prompt_weights_2.copy()
, pad_last_block=pad_last_block
)
# getting prompt embeddings one token at a time does not work; embed the prompt group by group
for i in range(len(prompt_token_groups)):
# get positive prompt embeddings with weights
token_tensor = torch.tensor(
[prompt_token_groups[i]]
, dtype=torch.long, device=pipe.text_encoder.device
)
weight_tensor = torch.tensor(
prompt_weight_groups[i]
, dtype=torch.float16
, device=pipe.text_encoder.device
)
token_tensor_2 = torch.tensor(
[prompt_token_groups_2[i]]
, dtype=torch.long, device=pipe.text_encoder_2.device
)
# use first text encoder
prompt_embeds_1 = pipe.text_encoder(
token_tensor.to(pipe.text_encoder.device)
, output_hidden_states=True
)
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
# use second text encoder
prompt_embeds_2 = pipe.text_encoder_2(
token_tensor_2.to(pipe.text_encoder_2.device)
, output_hidden_states=True
)
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
pooled_prompt_embeds = prompt_embeds_2[0]
prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0).to(pipe.text_encoder.device)
for j in range(len(weight_tensor)):
if weight_tensor[j] != 1.0:
# ow = weight_tensor[j] - 1
# optional process
# To map number of (0,1) to (-1,1)
# tanh_weight = (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2
# weight = 1 + tanh_weight
# add weight method 1:
# token_embedding[j] = token_embedding[j] * weight
# token_embedding[j] = (
# token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight
# )
# add weight method 2:
# token_embedding[j] = (
# token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
# )
# add weight method 3:
token_embedding[j] = token_embedding[j] * weight_tensor[j]
token_embedding = token_embedding.unsqueeze(0)
embeds.append(token_embedding)
# get negative prompt embeddings with weights
neg_token_tensor = torch.tensor(
[neg_prompt_token_groups[i]]
, dtype=torch.long, device=pipe.text_encoder.device
)
neg_token_tensor_2 = torch.tensor(
[neg_prompt_token_groups_2[i]]
, dtype=torch.long, device=pipe.text_encoder_2.device
)
neg_weight_tensor = torch.tensor(
neg_prompt_weight_groups[i]
, dtype=torch.float16
, device=pipe.text_encoder.device
)
# use first text encoder
neg_prompt_embeds_1 = pipe.text_encoder(
neg_token_tensor.to(pipe.text_encoder.device)
, output_hidden_states=True
)
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
# use second text encoder
neg_prompt_embeds_2 = pipe.text_encoder_2(
neg_token_tensor_2.to(pipe.text_encoder_2.device)
, output_hidden_states=True
)
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0).to(pipe.text_encoder.device)
for z in range(len(neg_weight_tensor)):
if neg_weight_tensor[z] != 1.0:
# ow = neg_weight_tensor[z] - 1
# neg_weight = 1 + (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2
# add weight method 1:
# neg_token_embedding[z] = neg_token_embedding[z] * neg_weight
# neg_token_embedding[z] = (
# neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight
# )
# add weight method 2:
# neg_token_embedding[z] = (
# neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
# )
# add weight method 3:
neg_token_embedding[z] = neg_token_embedding[z] * neg_weight_tensor[z]
neg_token_embedding = neg_token_embedding.unsqueeze(0)
neg_embeds.append(neg_token_embedding)
prompt_embeds = torch.cat(embeds, dim=1)
negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def get_weighted_text_embeddings_sdxl_refiner(
pipe: StableDiffusionXLPipeline
, prompt: str = ""
, neg_prompt: str = ""
):
"""
This function can process a long prompt with weights; no length limitation
for the Stable Diffusion XL refiner, which uses only the second text encoder
Args:
pipe (StableDiffusionXLPipeline)
prompt (str)
neg_prompt (str)
Returns:
prompt_embeds (torch.Tensor)
negative_prompt_embeds (torch.Tensor)
pooled_prompt_embeds (torch.Tensor)
negative_pooled_prompt_embeds (torch.Tensor)
Example:
from diffusers import StableDiffusionXLImg2ImgPipeline
refiner_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0"
, torch_dtype = torch.float16
).to("cuda:0")
(prompt_embeds, neg_prompt_embeds
, pooled_prompt_embeds, neg_pooled_prompt_embeds) = get_weighted_text_embeddings_sdxl_refiner(
pipe = refiner_pipe
, prompt = "a (white) cat"
, neg_prompt = "blur"
)
"""
eos = 49407 # hardcoded CLIP eos token id; the refiner pipeline has no first tokenizer
# tokenizer 2
prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(
pipe.tokenizer_2, prompt
)
neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(
pipe.tokenizer_2, neg_prompt
)
# padding the shorter one for token set 2
prompt_token_len_2 = len(prompt_tokens_2)
neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
if prompt_token_len_2 > neg_prompt_token_len_2:
# padding the neg_prompt with eos token
neg_prompt_tokens_2 = (
neg_prompt_tokens_2 +
[eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
neg_prompt_weights_2 = (
neg_prompt_weights_2 +
[1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
else:
# padding the prompt
prompt_tokens_2 = (
prompt_tokens_2
+ [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
prompt_weights_2 = (
prompt_weights_2
+ [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
embeds = []
neg_embeds = []
prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(
prompt_tokens_2.copy()
, prompt_weights_2.copy()
)
neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(
neg_prompt_tokens_2.copy()
, neg_prompt_weights_2.copy()
)
# getting prompt embeddings one token at a time does not work; embed the prompt group by group
for i in range(len(prompt_token_groups_2)):
# get positive prompt embeddings with weights
token_tensor_2 = torch.tensor(
[prompt_token_groups_2[i]]
, dtype=torch.long, device=pipe.text_encoder_2.device
)
weight_tensor_2 = torch.tensor(
prompt_weight_groups_2[i]
, dtype=torch.float16
, device=pipe.text_encoder_2.device
)
# use second text encoder
prompt_embeds_2 = pipe.text_encoder_2(
token_tensor_2.to(pipe.text_encoder_2.device)
, output_hidden_states=True
)
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
pooled_prompt_embeds = prompt_embeds_2[0]
prompt_embeds_list = [prompt_embeds_2_hidden_states]
token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)
for j in range(len(weight_tensor_2)):
if weight_tensor_2[j] != 1.0:
# ow = weight_tensor_2[j] - 1
# optional process
# To map number of (0,1) to (-1,1)
# tanh_weight = (math.exp(ow) / (math.exp(ow) + 1) - 0.5) * 2
# weight = 1 + tanh_weight
# add weight method 1:
# token_embedding[j] = token_embedding[j] * weight
# token_embedding[j] = (
# token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight
# )
# add weight method 2:
token_embedding[j] = (
token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor_2[j]
)
token_embedding = token_embedding.unsqueeze(0)
embeds.append(token_embedding)
# get negative prompt embeddings with weights
neg_token_tensor_2 = torch.tensor(
[neg_prompt_token_groups_2[i]]
, dtype=torch.long, device=pipe.text_encoder_2.device
)
neg_weight_tensor_2 = torch.tensor(
neg_prompt_weight_groups_2[i]
, dtype=torch.float16
, device=pipe.text_encoder_2.device
)
# use second text encoder
neg_prompt_embeds_2 = pipe.text_encoder_2(
neg_token_tensor_2.to(pipe.text_encoder_2.device)
, output_hidden_states=True
)
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
neg_prompt_embeds_list = [neg_prompt_embeds_2_hidden_states]
neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)
for z in range(len(neg_weight_tensor_2)):
if neg_weight_tensor_2[z] != 1.0:
# ow = neg_weight_tensor_2[z] - 1
# neg_weight = 1 + (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2
# add weight method 1:
# neg_token_embedding[z] = neg_token_embedding[z] * neg_weight
# neg_token_embedding[z] = (
# neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight
# )
# add weight method 2:
neg_token_embedding[z] = (
neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) *
neg_weight_tensor_2[z]
)
neg_token_embedding = neg_token_embedding.unsqueeze(0)
neg_embeds.append(neg_token_embedding)
prompt_embeds = torch.cat(embeds, dim=1)
negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def get_weighted_text_embeddings_sdxl_2p(
pipe: StableDiffusionXLPipeline
, prompt: str = ""
, prompt_2: str = None
, neg_prompt: str = ""
, neg_prompt_2: str = None
):
"""
This function can process a long prompt with weights; no length limitation
for Stable Diffusion XL, supporting two prompt sets.
Args:
pipe (StableDiffusionXLPipeline)
prompt (str)
prompt_2 (str): secondary prompt for the second text encoder; defaults to prompt
neg_prompt (str)
neg_prompt_2 (str): secondary negative prompt; defaults to neg_prompt
Returns:
prompt_embeds (torch.Tensor)
negative_prompt_embeds (torch.Tensor)
pooled_prompt_embeds (torch.Tensor)
negative_pooled_prompt_embeds (torch.Tensor)
Example:
from diffusers import StableDiffusionXLPipeline
text2img_pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0"
, torch_dtype = torch.float16
).to("cuda:0")
(prompt_embeds, neg_prompt_embeds
, pooled_prompt_embeds, neg_pooled_prompt_embeds) = get_weighted_text_embeddings_sdxl_2p(
pipe = text2img_pipe
, prompt = "a (white) cat"
, prompt_2 = "photo, realistic"
, neg_prompt = "blur"
, neg_prompt_2 = "low quality"
)
"""
prompt_2 = prompt_2 or prompt
neg_prompt_2 = neg_prompt_2 or neg_prompt
eos = pipe.tokenizer.eos_token_id
# tokenizer 1
prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
pipe.tokenizer, prompt
)
neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
pipe.tokenizer, neg_prompt
)
# tokenizer 2
prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(
pipe.tokenizer_2, prompt_2
)
neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(
pipe.tokenizer_2, neg_prompt_2
)
# padding the shorter one
prompt_token_len = len(prompt_tokens)
neg_prompt_token_len = len(neg_prompt_tokens)
if prompt_token_len > neg_prompt_token_len:
# padding the neg_prompt with eos token
neg_prompt_tokens = (
neg_prompt_tokens +
[eos] * abs(prompt_token_len - neg_prompt_token_len)
)
neg_prompt_weights = (
neg_prompt_weights +
[1.0] * abs(prompt_token_len - neg_prompt_token_len)
)
else:
# padding the prompt
prompt_tokens = (
prompt_tokens
+ [eos] * abs(prompt_token_len - neg_prompt_token_len)
)
prompt_weights = (
prompt_weights
+ [1.0] * abs(prompt_token_len - neg_prompt_token_len)
)
# padding the shorter one for token set 2
prompt_token_len_2 = len(prompt_tokens_2)
neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
if prompt_token_len_2 > neg_prompt_token_len_2:
# padding the neg_prompt with eos token
neg_prompt_tokens_2 = (
neg_prompt_tokens_2 +
[eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
neg_prompt_weights_2 = (
neg_prompt_weights_2 +
[1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
else:
# padding the prompt
prompt_tokens_2 = (
prompt_tokens_2
+ [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
prompt_weights_2 = (
prompt_weights_2
+ [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
# now ensure prompt and prompt_2 have the same length
prompt_token_len = len(prompt_tokens)
prompt_token_len_2 = len(prompt_tokens_2)
if prompt_token_len > prompt_token_len_2:
prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len - prompt_token_len_2)
prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len - prompt_token_len_2)
else:
prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - prompt_token_len_2)
prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - prompt_token_len_2)
# now ensure neg_prompt and neg_prompt_2 have the same length
neg_prompt_token_len = len(neg_prompt_tokens)
neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
if neg_prompt_token_len > neg_prompt_token_len_2:
neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(neg_prompt_token_len - neg_prompt_token_len_2)
neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(neg_prompt_token_len - neg_prompt_token_len_2)
else:
neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(neg_prompt_token_len - neg_prompt_token_len_2)
neg_prompt_weights = neg_prompt_weights + [1.0] * abs(neg_prompt_token_len - neg_prompt_token_len_2)
embeds = []
neg_embeds = []
prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
prompt_tokens.copy()
, prompt_weights.copy()
)
neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
neg_prompt_tokens.copy()
, neg_prompt_weights.copy()
)
prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(
prompt_tokens_2.copy()
, prompt_weights_2.copy()
)
neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(
neg_prompt_tokens_2.copy()
, neg_prompt_weights_2.copy()
)
# getting prompt embeddings one token at a time does not work; embed the prompt group by group
for i in range(len(prompt_token_groups)):
# get positive prompt embeddings with weights
token_tensor = torch.tensor(
[prompt_token_groups[i]]
, dtype=torch.long, device=pipe.text_encoder.device
)
weight_tensor = torch.tensor(
prompt_weight_groups[i]
, device=pipe.text_encoder.device
)
token_tensor_2 = torch.tensor(
[prompt_token_groups_2[i]]
, device=pipe.text_encoder_2.device
)
weight_tensor_2 = torch.tensor(
prompt_weight_groups_2[i]
, device=pipe.text_encoder_2.device
)
# use first text encoder
prompt_embeds_1 = pipe.text_encoder(
token_tensor.to(pipe.text_encoder.device)
, output_hidden_states=True
)
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
# use second text encoder
prompt_embeds_2 = pipe.text_encoder_2(
token_tensor_2.to(pipe.text_encoder_2.device)
, output_hidden_states=True
)
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
pooled_prompt_embeds = prompt_embeds_2[0]
prompt_embeds_1_hidden_states = prompt_embeds_1_hidden_states.squeeze(0)
prompt_embeds_2_hidden_states = prompt_embeds_2_hidden_states.squeeze(0)
for j in range(len(weight_tensor)):
if weight_tensor[j] != 1.0:
prompt_embeds_1_hidden_states[j] = (
prompt_embeds_1_hidden_states[-1] + (
prompt_embeds_1_hidden_states[j] - prompt_embeds_1_hidden_states[-1]) * weight_tensor[j]
)
if weight_tensor_2[j] != 1.0:
prompt_embeds_2_hidden_states[j] = (
prompt_embeds_2_hidden_states[-1] + (
prompt_embeds_2_hidden_states[j] - prompt_embeds_2_hidden_states[-1]) * weight_tensor_2[j]
)
prompt_embeds_1_hidden_states = prompt_embeds_1_hidden_states.unsqueeze(0)
prompt_embeds_2_hidden_states = prompt_embeds_2_hidden_states.unsqueeze(0)
prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
token_embedding = torch.cat(prompt_embeds_list, dim=-1)
embeds.append(token_embedding)
# get negative prompt embeddings with weights
neg_token_tensor = torch.tensor(
[neg_prompt_token_groups[i]]
, device=pipe.text_encoder.device
)
neg_token_tensor_2 = torch.tensor(
[neg_prompt_token_groups_2[i]]
, device=pipe.text_encoder_2.device
)
neg_weight_tensor = torch.tensor(
neg_prompt_weight_groups[i]
, device=pipe.text_encoder.device
)
neg_weight_tensor_2 = torch.tensor(
neg_prompt_weight_groups_2[i]
, device=pipe.text_encoder_2.device
)
# use first text encoder
neg_prompt_embeds_1 = pipe.text_encoder(
neg_token_tensor.to(pipe.text_encoder.device)
, output_hidden_states=True
)
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
# use second text encoder
neg_prompt_embeds_2 = pipe.text_encoder_2(
neg_token_tensor_2.to(pipe.text_encoder_2.device)
, output_hidden_states=True
)
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1_hidden_states.squeeze(0)
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2_hidden_states.squeeze(0)
for z in range(len(neg_weight_tensor)):
if neg_weight_tensor[z] != 1.0:
neg_prompt_embeds_1_hidden_states[z] = (
neg_prompt_embeds_1_hidden_states[-1] + (
neg_prompt_embeds_1_hidden_states[z] - neg_prompt_embeds_1_hidden_states[-1]) *
neg_weight_tensor[z]
)
if neg_weight_tensor_2[z] != 1.0:
neg_prompt_embeds_2_hidden_states[z] = (
neg_prompt_embeds_2_hidden_states[-1] + (
neg_prompt_embeds_2_hidden_states[z] - neg_prompt_embeds_2_hidden_states[-1]) *
neg_weight_tensor_2[z]
)
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1_hidden_states.unsqueeze(0)
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2_hidden_states.unsqueeze(0)
neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
neg_token_embedding = torch.cat(neg_prompt_embeds_list, dim=-1)
neg_embeds.append(neg_token_embedding)
prompt_embeds = torch.cat(embeds, dim=1)
negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def get_weighted_text_embeddings_sd3(
pipe: StableDiffusion3Pipeline
, prompt: str = ""
, neg_prompt: str = ""
, pad_last_block=True
, use_t5_encoder=True
):
"""
This function can process a long prompt with weights; no length limitation
for Stable Diffusion 3
Args:
pipe (StableDiffusion3Pipeline)
prompt (str)
neg_prompt (str)
pad_last_block (bool)
use_t5_encoder (bool): set False to skip the T5 encoder and use zero embeddings instead
Returns:
sd3_prompt_embeds (torch.Tensor)
sd3_neg_prompt_embeds (torch.Tensor)
pooled_prompt_embeds (torch.Tensor)
negative_pooled_prompt_embeds (torch.Tensor)
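Example (a minimal sketch; the checkpoint name is only illustrative):
from diffusers import StableDiffusion3Pipeline
pipe = StableDiffusion3Pipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers"
, torch_dtype = torch.float16
).to("cuda:0")
(prompt_embeds, neg_prompt_embeds
, pooled_prompt_embeds, neg_pooled_prompt_embeds) = get_weighted_text_embeddings_sd3(
pipe = pipe
, prompt = "a (white:1.2) cat"
, neg_prompt = "blur"
)
image = pipe(
prompt_embeds = prompt_embeds
, negative_prompt_embeds = neg_prompt_embeds
, pooled_prompt_embeds = pooled_prompt_embeds
, negative_pooled_prompt_embeds = neg_pooled_prompt_embeds
).images[0]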
"""
eos = pipe.tokenizer.eos_token_id
# tokenizer 1
prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
pipe.tokenizer, prompt
)
neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
pipe.tokenizer, neg_prompt
)
# tokenizer 2
prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(
pipe.tokenizer_2, prompt
)
neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(
pipe.tokenizer_2, neg_prompt
)
# tokenizer 3
prompt_tokens_3, prompt_weights_3, _ = get_prompts_tokens_with_weights_t5(
pipe.tokenizer_3, prompt
)
neg_prompt_tokens_3, neg_prompt_weights_3, _ = get_prompts_tokens_with_weights_t5(
pipe.tokenizer_3, neg_prompt
)
# padding the shorter one
prompt_token_len = len(prompt_tokens)
neg_prompt_token_len = len(neg_prompt_tokens)
if prompt_token_len > neg_prompt_token_len:
# padding the neg_prompt with eos token
neg_prompt_tokens = (
neg_prompt_tokens +
[eos] * abs(prompt_token_len - neg_prompt_token_len)
)
neg_prompt_weights = (
neg_prompt_weights +
[1.0] * abs(prompt_token_len - neg_prompt_token_len)
)
else:
# padding the prompt
prompt_tokens = (
prompt_tokens
+ [eos] * abs(prompt_token_len - neg_prompt_token_len)
)
prompt_weights = (
prompt_weights
+ [1.0] * abs(prompt_token_len - neg_prompt_token_len)
)
# padding the shorter one for token set 2
prompt_token_len_2 = len(prompt_tokens_2)
neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
if prompt_token_len_2 > neg_prompt_token_len_2:
# padding the neg_prompt with eos token
neg_prompt_tokens_2 = (
neg_prompt_tokens_2 +
[eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
neg_prompt_weights_2 = (
neg_prompt_weights_2 +
[1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
else:
# padding the prompt
prompt_tokens_2 = (
prompt_tokens_2
+ [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
prompt_weights_2 = (
prompt_weights_2
+ [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
)
embeds = []
neg_embeds = []
prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
prompt_tokens.copy()
, prompt_weights.copy()
, pad_last_block=pad_last_block
)
neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
neg_prompt_tokens.copy()
, neg_prompt_weights.copy()
, pad_last_block=pad_last_block
)
prompt_token_groups_2, _prompt_weight_groups_2 = group_tokens_and_weights(
prompt_tokens_2.copy()
, prompt_weights_2.copy()
, pad_last_block=pad_last_block
)
neg_prompt_token_groups_2, _neg_prompt_weight_groups_2 = group_tokens_and_weights(
neg_prompt_tokens_2.copy()
, neg_prompt_weights_2.copy()
, pad_last_block=pad_last_block
)
# get prompt embeddings one by one is not working.
for i in range(len(prompt_token_groups)):
# get positive prompt embeddings with weights
token_tensor = torch.tensor(
[prompt_token_groups[i]]
, dtype=torch.long, device=pipe.text_encoder.device
)
weight_tensor = torch.tensor(
prompt_weight_groups[i]
, dtype=torch.float16
, device=pipe.text_encoder.device
)
token_tensor_2 = torch.tensor(
[prompt_token_groups_2[i]]
, dtype=torch.long, device=pipe.text_encoder_2.device
)
# use first text encoder
prompt_embeds_1 = pipe.text_encoder(
token_tensor.to(pipe.text_encoder.device)
, output_hidden_states=True
)
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
pooled_prompt_embeds_1 = prompt_embeds_1[0]
# use second text encoder
prompt_embeds_2 = pipe.text_encoder_2(
token_tensor_2.to(pipe.text_encoder_2.device)
, output_hidden_states=True
)
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
pooled_prompt_embeds_2 = prompt_embeds_2[0]
prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0).to(pipe.text_encoder.device)
for j in range(len(weight_tensor)):
if weight_tensor[j] != 1.0:
# ow = weight_tensor[j] - 1
# optional process
# To map number of (0,1) to (-1,1)
# tanh_weight = (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2
# weight = 1 + tanh_weight
# add weight method 1:
# token_embedding[j] = token_embedding[j] * weight
# token_embedding[j] = (
# token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight
# )
# add weight method 2:
# token_embedding[j] = (
# token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
# )
# add weight method 3:
token_embedding[j] = token_embedding[j] * weight_tensor[j]
token_embedding = token_embedding.unsqueeze(0)
embeds.append(token_embedding)
# get negative prompt embeddings with weights
neg_token_tensor = torch.tensor(
[neg_prompt_token_groups[i]]
, dtype=torch.long, device=pipe.text_encoder.device
)
neg_token_tensor_2 = torch.tensor(
[neg_prompt_token_groups_2[i]]
, dtype=torch.long, device=pipe.text_encoder_2.device
)
neg_weight_tensor = torch.tensor(
neg_prompt_weight_groups[i]
, dtype=torch.float16
, device=pipe.text_encoder.device
)
# use first text encoder
neg_prompt_embeds_1 = pipe.text_encoder(
neg_token_tensor.to(pipe.text_encoder.device)
, output_hidden_states=True
)
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
negative_pooled_prompt_embeds_1 = neg_prompt_embeds_1[0]
# use second text encoder
neg_prompt_embeds_2 = pipe.text_encoder_2(
neg_token_tensor_2.to(pipe.text_encoder_2.device)
, output_hidden_states=True
)
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
negative_pooled_prompt_embeds_2 = neg_prompt_embeds_2[0]
neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0).to(pipe.text_encoder.device)
for z in range(len(neg_weight_tensor)):
if neg_weight_tensor[z] != 1.0:
# ow = neg_weight_tensor[z] - 1
# neg_weight = 1 + (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2
# add weight method 1:
# neg_token_embedding[z] = neg_token_embedding[z] * neg_weight
# neg_token_embedding[z] = (
# neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight
# )
# add weight method 2:
# neg_token_embedding[z] = (
# neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
# )
# add weight method 3:
neg_token_embedding[z] = neg_token_embedding[z] * neg_weight_tensor[z]
neg_token_embedding = neg_token_embedding.unsqueeze(0)
neg_embeds.append(neg_token_embedding)
prompt_embeds = torch.cat(embeds, dim=1)
negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
negative_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2],
dim=-1)
if use_t5_encoder and pipe.text_encoder_3:
# ----------------- generate positive t5 embeddings --------------------
prompt_tokens_3 = torch.tensor([prompt_tokens_3], dtype=torch.long)
t5_prompt_embeds = pipe.text_encoder_3(prompt_tokens_3.to(pipe.text_encoder_3.device))[0].squeeze(0)
t5_prompt_embeds = t5_prompt_embeds.to(device=pipe.text_encoder_3.device)
# add weight to t5 prompt
for z in range(len(prompt_weights_3)):
if prompt_weights_3[z] != 1.0:
t5_prompt_embeds[z] = t5_prompt_embeds[z] * prompt_weights_3[z]
t5_prompt_embeds = t5_prompt_embeds.unsqueeze(0)
else:
t5_prompt_embeds = torch.zeros(1, 4096, dtype=prompt_embeds.dtype).unsqueeze(0)
t5_prompt_embeds = t5_prompt_embeds.to(device=pipe.text_encoder.device) # text_encoder_3 may be None when this branch runs
# merge with the clip embedding 1 and clip embedding 2
clip_prompt_embeds = torch.nn.functional.pad(
prompt_embeds, (0, t5_prompt_embeds.shape[-1] - prompt_embeds.shape[-1])
)
sd3_prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embeds], dim=-2)
if use_t5_encoder and pipe.text_encoder_3:
# ---------------------- get neg t5 embeddings -------------------------
neg_prompt_tokens_3 = torch.tensor([neg_prompt_tokens_3], dtype=torch.long)
t5_neg_prompt_embeds = pipe.text_encoder_3(neg_prompt_tokens_3.to(pipe.text_encoder_3.device))[0].squeeze(0)
t5_neg_prompt_embeds = t5_neg_prompt_embeds.to(device=pipe.text_encoder_3.device)
# add weight to neg t5 embeddings
for z in range(len(neg_prompt_weights_3)):
if neg_prompt_weights_3[z] != 1.0:
t5_neg_prompt_embeds[z] = t5_neg_prompt_embeds[z] * neg_prompt_weights_3[z]
t5_neg_prompt_embeds = t5_neg_prompt_embeds.unsqueeze(0)
else:
t5_neg_prompt_embeds = torch.zeros(1, 4096, dtype=prompt_embeds.dtype).unsqueeze(0)
t5_neg_prompt_embeds = t5_neg_prompt_embeds.to(device=pipe.text_encoder.device) # text_encoder_3 may be None when this branch runs
clip_neg_prompt_embeds = torch.nn.functional.pad(
negative_prompt_embeds, (0, t5_neg_prompt_embeds.shape[-1] - negative_prompt_embeds.shape[-1])
)
sd3_neg_prompt_embeds = torch.cat([clip_neg_prompt_embeds, t5_neg_prompt_embeds], dim=-2)
# padding
size_diff = sd3_neg_prompt_embeds.size(1) - sd3_prompt_embeds.size(1)
# Calculate padding. Format for pad is (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
# Since we are padding along the second dimension (axis=1), we need (0, 0, padding_top, padding_bottom, 0, 0)
# Here padding_top will be 0 and padding_bottom will be size_diff
# Check if padding is needed
if size_diff > 0:
padding = (0, 0, 0, abs(size_diff), 0, 0)
sd3_prompt_embeds = F.pad(sd3_prompt_embeds, padding)
elif size_diff < 0:
padding = (0, 0, 0, abs(size_diff), 0, 0)
sd3_neg_prompt_embeds = F.pad(sd3_neg_prompt_embeds, padding)
return sd3_prompt_embeds, sd3_neg_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def get_weighted_text_embeddings_flux1(
pipe: FluxPipeline
, prompt: str = ""
, prompt2: str = None
, device=None
):
"""
This function can process a long prompt with weights for the Flux.1 model
Args:
pipe (FluxPipeline)
prompt (str)
prompt2 (str, optional): secondary prompt for the T5 encoder; defaults to prompt
device (torch.device, optional)
Returns:
t5_prompt_embeds (torch.Tensor)
prompt_embeds (torch.Tensor): averaged pooled CLIP embeddings
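Example (a minimal sketch; the checkpoint name is only illustrative):
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev"
, torch_dtype = torch.bfloat16
).to("cuda:0")
t5_prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(
pipe = pipe
, prompt = "a (white:1.2) cat"
)
image = pipe(
prompt_embeds = t5_prompt_embeds
, pooled_prompt_embeds = pooled_prompt_embeds
).images[0]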
"""
prompt2 = prompt if prompt2 is None else prompt2
if device is None:
device = pipe.text_encoder.device
# tokenizer 1 - openai/clip-vit-large-patch14
prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
pipe.tokenizer, prompt
)
# tokenizer 2 - google/t5-v1_1-xxl
prompt_tokens_2, prompt_weights_2, _ = get_prompts_tokens_with_weights_t5(
pipe.tokenizer_2, prompt2
)
prompt_token_groups, _prompt_weight_groups = group_tokens_and_weights(
prompt_tokens.copy()
, prompt_weights.copy()
, pad_last_block=True
)
# # get positive prompt embeddings; flux1 uses only the text_encoder 1 pooled embeddings
# token_tensor = torch.tensor(
# [prompt_token_groups[0]]
# , dtype = torch.long, device = device
# )
# # use first text encoder
# prompt_embeds_1 = pipe.text_encoder(
# token_tensor.to(device)
# , output_hidden_states = False
# )
# pooled_prompt_embeds_1 = prompt_embeds_1.pooler_output
# prompt_embeds = pooled_prompt_embeds_1.to(dtype = pipe.text_encoder.dtype, device = device)
# instead, average the pooled embeddings across all token groups
pool_embeds_list = []
for token_group in prompt_token_groups:
token_tensor = torch.tensor(
[token_group]
, dtype=torch.long
, device=device
)
prompt_embeds_1 = pipe.text_encoder(
token_tensor.to(device)
, output_hidden_states=False
)
pooled_prompt_embeds = prompt_embeds_1.pooler_output.squeeze(0)
pool_embeds_list.append(pooled_prompt_embeds)
prompt_embeds = torch.stack(pool_embeds_list, dim=0)
# get the avg pool
prompt_embeds = prompt_embeds.mean(dim=0, keepdim=True)
# prompt_embeds = prompt_embeds.unsqueeze(0)
prompt_embeds = prompt_embeds.to(dtype=pipe.text_encoder.dtype, device=device)
# generate positive t5 embeddings
prompt_tokens_2 = torch.tensor([prompt_tokens_2], dtype=torch.long)
t5_prompt_embeds = pipe.text_encoder_2(prompt_tokens_2.to(device))[0].squeeze(0)
t5_prompt_embeds = t5_prompt_embeds.to(device=device)
# add weight to t5 prompt
for z in range(len(prompt_weights_2)):
if prompt_weights_2[z] != 1.0:
t5_prompt_embeds[z] = t5_prompt_embeds[z] * prompt_weights_2[z]
t5_prompt_embeds = t5_prompt_embeds.unsqueeze(0)
t5_prompt_embeds = t5_prompt_embeds.to(dtype=pipe.text_encoder_2.dtype, device=device)
return t5_prompt_embeds, prompt_embeds
def get_weighted_text_embeddings_chroma(
pipe: ChromaPipeline,
prompt: str = "",
neg_prompt: str = "",
device=None
):
"""
This function can process long prompt with weights for Chroma model
Args:
pipe (ChromaPipeline)
prompt (str)
neg_prompt (str)
device (torch.device, optional): Device to run the embeddings on.
Returns:
prompt_embeds (torch.Tensor)
prompt_attention_mask (torch.Tensor)
neg_prompt_embeds (torch.Tensor)
neg_prompt_attention_mask (torch.Tensor)
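Example (a minimal sketch; the checkpoint name is only illustrative):
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained(
"lodestones/Chroma"
, torch_dtype = torch.bfloat16
).to("cuda:0")
(prompt_embeds, prompt_masks
, neg_prompt_embeds, neg_prompt_masks) = get_weighted_text_embeddings_chroma(
pipe = pipe
, prompt = "a (white:1.2) cat"
, neg_prompt = "blur"
)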
"""
if device is None:
device = pipe.text_encoder.device
dtype = pipe.text_encoder.dtype
prompt_tokens, prompt_weights, prompt_masks = get_prompts_tokens_with_weights_t5(
pipe.tokenizer, prompt, add_special_tokens=False
)
neg_prompt_tokens, neg_prompt_weights, neg_prompt_masks = get_prompts_tokens_with_weights_t5(
pipe.tokenizer, neg_prompt, add_special_tokens=False
)
prompt_tokens, prompt_weights, prompt_masks = pad_prompt_tokens_to_length_chroma(
pipe,
prompt_tokens,
prompt_weights,
prompt_masks
)
prompt_embeds, prompt_masks = get_weighted_prompt_embeds_with_attention_mask_chroma(
pipe,
prompt_tokens,
prompt_weights,
prompt_masks,
device=device,
dtype=dtype)
neg_prompt_tokens, neg_prompt_weights, neg_prompt_masks = pad_prompt_tokens_to_length_chroma(
pipe,
neg_prompt_tokens,
neg_prompt_weights,
neg_prompt_masks
)
neg_prompt_embeds, neg_prompt_masks = get_weighted_prompt_embeds_with_attention_mask_chroma(
pipe,
neg_prompt_tokens,
neg_prompt_weights,
neg_prompt_masks,
device=device,
dtype=dtype)
return prompt_embeds, prompt_masks, neg_prompt_embeds, neg_prompt_masks
def get_weighted_prompt_embeds_with_attention_mask_chroma(
pipe: ChromaPipeline,
tokens,
weights,
masks,
device,
dtype
):
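"""
Encode one token sequence with the T5 text encoder using its attention mask,
then scale each token embedding by its parsed weight.
Returns the weighted prompt embeddings and the attention mask tensor.
"""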
prompt_tokens = torch.tensor([tokens], dtype=torch.long, device=device)
prompt_masks = torch.tensor([masks], dtype=torch.long, device=device)
prompt_embeds = pipe.text_encoder(prompt_tokens, output_hidden_states=False, attention_mask=prompt_masks)[0].squeeze(0)
for z in range(len(weights)):
if weights[z] != 1.0:
prompt_embeds[z] = prompt_embeds[z] * weights[z]
prompt_embeds = prompt_embeds.unsqueeze(0).to(dtype=dtype, device=device)
return prompt_embeds, prompt_masks
def pad_prompt_tokens_to_length_chroma(pipe, input_tokens, input_weights, input_masks, min_length=5, add_eos_token=True):
"""
Implementation of Chroma's padding scheme for prompt tokens.
Pads the token sequence to at least min_length and masks the padding tokens,
keeping at least one padding token and the eos token unmasked.
https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
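Example (illustrative, assuming T5's pad_token_id=0 and eos_token_id=1):
tokens [10, 20, 30] with min_length=5 become [10, 20, 30, 0, 0, 1]
with masks [1, 1, 1, 0, 1, 1]: padding is masked except for the final
pad token, and the appended eos token stays unmasked.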
"""
output_tokens = input_tokens.copy()
output_weights = input_weights.copy()
output_masks = input_masks.copy()
pad_token_id = pipe.tokenizer.pad_token_id
eos_token_id = pipe.tokenizer.eos_token_id
pad_length = 1
for j, token in enumerate(output_tokens):
if token == pad_token_id:
output_masks[j] = 0
pad_length = 0
current_length = len(output_tokens)
if current_length < min_length:
pad_length = min_length - current_length
if pad_length > 0:
output_tokens += [pad_token_id] * pad_length
output_weights += [1.0] * pad_length
output_masks += [0] * pad_length
output_masks[-1] = 1
if add_eos_token and output_tokens[-1] != eos_token_id:
output_tokens += [eos_token_id]
output_weights += [1.0]
output_masks += [1]
return output_tokens, output_weights, output_masks