## -----------------------------------------------------------------------------
# Generate unlimited size prompt with weighting for SD3 & SDXL & SD15
#
# If you use sd_embed in your research, please cite the following work:
#
# ```
# @misc{sd_embed_2024,
#   author       = {Shudong Zhu (Andrew Zhu)},
#   title        = {Long Prompt Weighted Stable Diffusion Embedding},
#   howpublished = {\url{https://github.com/xhinker/sd_embed}},
#   year         = {2024},
# }
# ```
#
# Author: Andrew Zhu
# Book:   Using Stable Diffusion with Python, https://www.amazon.com/Using-Stable-Diffusion-Python-Generation/dp/1835086373
# Github: https://github.com/xhinker
# Medium: https://medium.com/@xhinker
## -----------------------------------------------------------------------------

import torch
import torch.nn.functional as F
from transformers import CLIPTokenizer, T5Tokenizer
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, StableDiffusion3Pipeline, FluxPipeline, ChromaPipeline
from modules.prompt_parser import parse_prompt_attention # use built-in A1111 parser


def get_prompts_tokens_with_weights(
    clip_tokenizer: CLIPTokenizer
    , prompt: str = None
):
    """
    Get prompt token ids and weights; this function works for both prompt and negative prompt

    Args:
        clip_tokenizer (CLIPTokenizer)
            A CLIPTokenizer
        prompt (str)
            A prompt string with weights

    Returns:
        text_tokens (list)
            A list containing token ids
        text_weights (list)
            A list containing the corresponding weight of each token id

    Example:
        import torch
        from diffusers_plus.tools.sd_embeddings import get_prompts_tokens_with_weights
        from transformers import CLIPTokenizer

        clip_tokenizer = CLIPTokenizer.from_pretrained(
            "stablediffusionapi/deliberate-v2"
            , subfolder = "tokenizer"
        )

        token_id_list, token_weight_list = get_prompts_tokens_with_weights(
            clip_tokenizer = clip_tokenizer
            , prompt = "a (red:1.5) cat"*70
        )
    """
    if (prompt is None) or (len(prompt) < 1):
        prompt = "empty"

    texts_and_weights = parse_prompt_attention(prompt)
    text_tokens, text_weights = [], []
    for word, weight in texts_and_weights:
        # tokenize and discard the starting and the ending token
        token = clip_tokenizer(
            word
            , truncation=False # tokenize prompts of any length
        ).input_ids[1:-1]
        # the returned token is a 1d list: [320, 1125, 539, 320]

        # merge the new tokens into the running token list
        text_tokens = [*text_tokens, *token]

        # each text chunk comes with a single weight, like ['red cat', 2.0],
        # so the weight must be expanded to one entry per token
        chunk_weights = [weight] * len(token)

        # append the weights to the running weight list
        text_weights = [*text_weights, *chunk_weights]
    return text_tokens, text_weights


def get_prompts_tokens_with_weights_t5(
    t5_tokenizer: T5Tokenizer,
    prompt: str,
    add_special_tokens: bool = True
):
    """
    Get prompt token ids, weights and attention masks using a T5 tokenizer;
    this function works for both prompt and negative prompt
    """
    if (prompt is None) or (len(prompt) < 1):
        prompt = "empty"

    texts_and_weights = parse_prompt_attention(prompt)
    text_tokens, text_weights, text_masks = [], [], []
    for word, weight in texts_and_weights:
        # tokenize without truncation so prompts of any length are accepted
        inputs = t5_tokenizer(
            word,
            truncation=False,
            add_special_tokens=add_special_tokens,
            return_length=False,
        )

        token = inputs.input_ids
        mask = inputs.attention_mask

        # merge the new tokens and masks into the running lists
        text_tokens = [*text_tokens, *token]
        text_masks = [*text_masks, *mask]

        # each text chunk comes with a single weight, like ['red cat', 2.0],
        # so the weight must be expanded to one entry per token
        chunk_weights = [weight] * len(token)

        # append the weights to the running weight list
        text_weights = [*text_weights, *chunk_weights]
    return text_tokens, text_weights, text_masks


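# Usage sketch for the T5 variant above (illustrative only; the tokenizer repo id
# "google/t5-v1_1-xxl" is an assumption, any T5 tokenizer works the same way):
#
#   from transformers import T5Tokenizer
#   t5_tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
#   tokens, weights, masks = get_prompts_tokens_with_weights_t5(
#       t5_tokenizer, "a (red:1.5) cat"
#   )
#   # tokens, weights and masks are flat lists of equal length,
#   # with every token of the "(red:1.5)" chunk carrying weight 1.5

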
def group_tokens_and_weights(
    token_ids: list
    , weights: list
    , pad_last_block=False
):
    """
    Produce tokens and weights in groups and pad the missing tokens

    Args:
        token_ids (list)
            The token ids from tokenizer
        weights (list)
            The weights list from function get_prompts_tokens_with_weights
        pad_last_block (bool)
            Controls whether to fill the last token group up to 75 tokens with eos
    Returns:
        new_token_ids (2d list)
        new_weights (2d list)

    Example:
        from diffusers_plus.tools.sd_embeddings import group_tokens_and_weights
        token_groups, weight_groups = group_tokens_and_weights(
            token_ids = token_id_list
            , weights = token_weight_list
        )
    """
    bos, eos = 49406, 49407

    # these will be 2d lists
    new_token_ids = []
    new_weights = []
    while len(token_ids) >= 75:
        # get the first 75 tokens
        head_75_tokens = [token_ids.pop(0) for _ in range(75)]
        head_75_weights = [weights.pop(0) for _ in range(75)]

        # wrap the chunk with bos/eos tokens to form a full 77-token block
        temp_77_token_ids = [bos] + head_75_tokens + [eos]
        temp_77_weights = [1.0] + head_75_weights + [1.0]

        # add the 77-token chunk and its weights to the holder lists
        new_token_ids.append(temp_77_token_ids)
        new_weights.append(temp_77_weights)

    # pad the remaining tokens
    if len(token_ids) > 0:
        padding_len = 75 - len(token_ids) if pad_last_block else 0

        temp_77_token_ids = [bos] + token_ids + [eos] * padding_len + [eos]
        new_token_ids.append(temp_77_token_ids)

        temp_77_weights = [1.0] + weights + [1.0] * padding_len + [1.0]
        new_weights.append(temp_77_weights)

    return new_token_ids, new_weights


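# Illustration of the grouping above (a minimal sketch; the token ids are made up):
# 100 input tokens become two chunks of 77 ids each, wrapped in bos (49406) and
# eos (49407); with pad_last_block=True the second chunk is right-padded with eos.
#
#   token_groups, weight_groups = group_tokens_and_weights(
#       token_ids=list(range(100)), weights=[1.0] * 100, pad_last_block=True
#   )
#   assert len(token_groups) == 2
#   assert len(token_groups[0]) == 77 and len(token_groups[1]) == 77
#   assert token_groups[0][0] == 49406 and token_groups[0][-1] == 49407

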
def get_weighted_text_embeddings_sd15(
    pipe: StableDiffusionPipeline
    , prompt: str = ""
    , neg_prompt: str = ""
    , pad_last_block=False
    , clip_skip: int = 0
):
    """
    This function can process long prompts with weights, with no length limitation,
    for Stable Diffusion v1.5

    Args:
        pipe (StableDiffusionPipeline)
        prompt (str)
        neg_prompt (str)
    Returns:
        prompt_embeds (torch.Tensor)
        neg_prompt_embeds (torch.Tensor)

    Example:
        from diffusers import StableDiffusionPipeline
        text2img_pipe = StableDiffusionPipeline.from_pretrained(
            "stablediffusionapi/deliberate-v2"
            , torch_dtype = torch.float16
            , safety_checker = None
        ).to("cuda:0")
        prompt_embeds, neg_prompt_embeds = get_weighted_text_embeddings_sd15(
            pipe = text2img_pipe
            , prompt = "a (white) cat"
            , neg_prompt = "blur"
        )
        image = text2img_pipe(
            prompt_embeds = prompt_embeds
            , negative_prompt_embeds = neg_prompt_embeds
            , generator = torch.Generator(text2img_pipe.device).manual_seed(2)
        ).images[0]
    """
    original_clip_layers = pipe.text_encoder.text_model.encoder.layers
    if clip_skip > 0:
        # temporarily drop the last clip_skip encoder layers
        pipe.text_encoder.text_model.encoder.layers = original_clip_layers[:-clip_skip]

    eos = pipe.tokenizer.eos_token_id
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, prompt
    )
    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, neg_prompt
    )

    # pad the shorter one with eos tokens
    prompt_token_len = len(prompt_tokens)
    neg_prompt_token_len = len(neg_prompt_tokens)
    if prompt_token_len > neg_prompt_token_len:
        # pad the neg_prompt with eos tokens
        neg_prompt_tokens = (
            neg_prompt_tokens +
            [eos] * abs(prompt_token_len - neg_prompt_token_len)
        )
        neg_prompt_weights = (
            neg_prompt_weights +
            [1.0] * abs(prompt_token_len - neg_prompt_token_len)
        )
    else:
        # pad the prompt
        prompt_tokens = (
            prompt_tokens
            + [eos] * abs(prompt_token_len - neg_prompt_token_len)
        )
        prompt_weights = (
            prompt_weights
            + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
        )

    embeds = []
    neg_embeds = []

    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
        prompt_tokens.copy()
        , prompt_weights.copy()
        , pad_last_block=pad_last_block
    )

    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
        neg_prompt_tokens.copy()
        , neg_prompt_weights.copy()
        , pad_last_block=pad_last_block
    )

    # embedding tokens one by one does not work,
    # the prompt must be embedded chunk by chunk
    for i in range(len(prompt_token_groups)):
        # get positive prompt embeddings with weights
        token_tensor = torch.tensor(
            [prompt_token_groups[i]]
            , dtype=torch.long, device=pipe.text_encoder.device
        )
        weight_tensor = torch.tensor(
            prompt_weight_groups[i]
            , dtype=torch.float16
            , device=pipe.text_encoder.device
        )

        token_embedding = pipe.text_encoder(token_tensor)[0].squeeze(0)
        for j in range(len(weight_tensor)):
            token_embedding[j] = token_embedding[j] * weight_tensor[j]
        token_embedding = token_embedding.unsqueeze(0)
        embeds.append(token_embedding)

        # get negative prompt embeddings with weights
        neg_token_tensor = torch.tensor(
            [neg_prompt_token_groups[i]]
            , dtype=torch.long, device=pipe.text_encoder.device
        )
        neg_weight_tensor = torch.tensor(
            neg_prompt_weight_groups[i]
            , dtype=torch.float16
            , device=pipe.text_encoder.device
        )
        neg_token_embedding = pipe.text_encoder(neg_token_tensor)[0].squeeze(0)
        for z in range(len(neg_weight_tensor)):
            neg_token_embedding[z] = (
                neg_token_embedding[z] * neg_weight_tensor[z]
            )
        neg_token_embedding = neg_token_embedding.unsqueeze(0)
        neg_embeds.append(neg_token_embedding)

    prompt_embeds = torch.cat(embeds, dim=1)
    neg_prompt_embeds = torch.cat(neg_embeds, dim=1)

    # restore the original clip layers
    if clip_skip > 0:
        pipe.text_encoder.text_model.encoder.layers = original_clip_layers

    return prompt_embeds, neg_prompt_embeds


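# clip_skip usage sketch for the function above (illustrative; prompt and values
# are arbitrary). Note that here clip_skip directly counts the number of final
# CLIP encoder layers removed, so clip_skip=1 drops the last layer:
#
#   prompt_embeds, neg_prompt_embeds = get_weighted_text_embeddings_sd15(
#       pipe=text2img_pipe,
#       prompt="a (white:1.2) cat",
#       neg_prompt="blur",
#       clip_skip=1,
#   )

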
def get_weighted_text_embeddings_sdxl(
    pipe: StableDiffusionXLPipeline
    , prompt: str = ""
    , neg_prompt: str = ""
    , pad_last_block=True
):
    """
    This function can process long prompts with weights, with no length limitation,
    for Stable Diffusion XL

    Args:
        pipe (StableDiffusionXLPipeline)
        prompt (str)
        neg_prompt (str)
    Returns:
        prompt_embeds (torch.Tensor)
        negative_prompt_embeds (torch.Tensor)
        pooled_prompt_embeds (torch.Tensor)
        negative_pooled_prompt_embeds (torch.Tensor)

    Example:
        from diffusers import StableDiffusionXLPipeline
        text2img_pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0"
            , torch_dtype = torch.float16
        ).to("cuda:0")
        (
            prompt_embeds
            , negative_prompt_embeds
            , pooled_prompt_embeds
            , negative_pooled_prompt_embeds
        ) = get_weighted_text_embeddings_sdxl(
            pipe = text2img_pipe
            , prompt = "a (white) cat"
            , neg_prompt = "blur"
        )
        image = text2img_pipe(
            prompt_embeds = prompt_embeds
            , negative_prompt_embeds = negative_prompt_embeds
            , pooled_prompt_embeds = pooled_prompt_embeds
            , negative_pooled_prompt_embeds = negative_pooled_prompt_embeds
            , generator = torch.Generator(text2img_pipe.device).manual_seed(2)
        ).images[0]
    """
    eos = pipe.tokenizer.eos_token_id

    # tokenizer 1
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, prompt
    )

    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, neg_prompt
    )

    # tokenizer 2
    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(
        pipe.tokenizer_2, prompt
    )

    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(
        pipe.tokenizer_2, neg_prompt
    )

    # pad the shorter one with eos tokens
    prompt_token_len = len(prompt_tokens)
    neg_prompt_token_len = len(neg_prompt_tokens)

    if prompt_token_len > neg_prompt_token_len:
        # pad the neg_prompt with eos tokens
        neg_prompt_tokens = (
            neg_prompt_tokens +
            [eos] * abs(prompt_token_len - neg_prompt_token_len)
        )
        neg_prompt_weights = (
            neg_prompt_weights +
            [1.0] * abs(prompt_token_len - neg_prompt_token_len)
        )
    else:
        # pad the prompt
        prompt_tokens = (
            prompt_tokens
            + [eos] * abs(prompt_token_len - neg_prompt_token_len)
        )
        prompt_weights = (
            prompt_weights
            + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
        )

    # pad the shorter one for token set 2
    prompt_token_len_2 = len(prompt_tokens_2)
    neg_prompt_token_len_2 = len(neg_prompt_tokens_2)

    if prompt_token_len_2 > neg_prompt_token_len_2:
        # pad the neg_prompt with eos tokens
        neg_prompt_tokens_2 = (
            neg_prompt_tokens_2 +
            [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
        neg_prompt_weights_2 = (
            neg_prompt_weights_2 +
            [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
    else:
        # pad the prompt
        prompt_tokens_2 = (
            prompt_tokens_2
            + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
        prompt_weights_2 = (
            prompt_weights_2
            + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )

    embeds = []
    neg_embeds = []

    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
        prompt_tokens.copy()
        , prompt_weights.copy()
        , pad_last_block=pad_last_block
    )

    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
        neg_prompt_tokens.copy()
        , neg_prompt_weights.copy()
        , pad_last_block=pad_last_block
    )

    prompt_token_groups_2, _prompt_weight_groups_2 = group_tokens_and_weights(
        prompt_tokens_2.copy()
        , prompt_weights_2.copy()
        , pad_last_block=pad_last_block
    )

    neg_prompt_token_groups_2, _neg_prompt_weight_groups_2 = group_tokens_and_weights(
        neg_prompt_tokens_2.copy()
        , neg_prompt_weights_2.copy()
        , pad_last_block=pad_last_block
    )

    # embedding tokens one by one does not work, embed chunk by chunk instead
    for i in range(len(prompt_token_groups)):
        # get positive prompt embeddings with weights
        token_tensor = torch.tensor(
            [prompt_token_groups[i]]
            , dtype=torch.long, device=pipe.text_encoder.device
        )
        weight_tensor = torch.tensor(
            prompt_weight_groups[i]
            , dtype=torch.float16
            , device=pipe.text_encoder.device
        )

        token_tensor_2 = torch.tensor(
            [prompt_token_groups_2[i]]
            , dtype=torch.long, device=pipe.text_encoder_2.device
        )

        # use first text encoder
        prompt_embeds_1 = pipe.text_encoder(
            token_tensor.to(pipe.text_encoder.device)
            , output_hidden_states=True
        )
        prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]

        # use second text encoder
        prompt_embeds_2 = pipe.text_encoder_2(
            token_tensor_2.to(pipe.text_encoder_2.device)
            , output_hidden_states=True
        )
        prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
        pooled_prompt_embeds = prompt_embeds_2[0]

        prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
        token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0).to(pipe.text_encoder.device)

        for j in range(len(weight_tensor)):
            if weight_tensor[j] != 1.0:
                # optional: map the weight offset through tanh to (-1, 1)
                # ow = weight_tensor[j] - 1
                # tanh_weight = (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2
                # weight = 1 + tanh_weight

                # add weight method 1:
                # token_embedding[j] = token_embedding[j] * weight
                # token_embedding[j] = (
                #     token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight
                # )

                # add weight method 2:
                # token_embedding[j] = (
                #     token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
                # )

                # add weight method 3:
                token_embedding[j] = token_embedding[j] * weight_tensor[j]

        token_embedding = token_embedding.unsqueeze(0)
        embeds.append(token_embedding)

        # get negative prompt embeddings with weights
        neg_token_tensor = torch.tensor(
            [neg_prompt_token_groups[i]]
            , dtype=torch.long, device=pipe.text_encoder.device
        )
        neg_token_tensor_2 = torch.tensor(
            [neg_prompt_token_groups_2[i]]
            , dtype=torch.long, device=pipe.text_encoder_2.device
        )
        neg_weight_tensor = torch.tensor(
            neg_prompt_weight_groups[i]
            , dtype=torch.float16
            , device=pipe.text_encoder.device
        )

        # use first text encoder
        neg_prompt_embeds_1 = pipe.text_encoder(
            neg_token_tensor.to(pipe.text_encoder.device)
            , output_hidden_states=True
        )
        neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]

        # use second text encoder
        neg_prompt_embeds_2 = pipe.text_encoder_2(
            neg_token_tensor_2.to(pipe.text_encoder_2.device)
            , output_hidden_states=True
        )
        neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
        negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]

        neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
        neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0).to(pipe.text_encoder.device)

        for z in range(len(neg_weight_tensor)):
            if neg_weight_tensor[z] != 1.0:
                # ow = neg_weight_tensor[z] - 1
                # neg_weight = 1 + (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2

                # add weight method 1:
                # neg_token_embedding[z] = neg_token_embedding[z] * neg_weight
                # neg_token_embedding[z] = (
                #     neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight
                # )

                # add weight method 2:
                # neg_token_embedding[z] = (
                #     neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
                # )

                # add weight method 3:
                neg_token_embedding[z] = neg_token_embedding[z] * neg_weight_tensor[z]

        neg_token_embedding = neg_token_embedding.unsqueeze(0)
        neg_embeds.append(neg_token_embedding)

    prompt_embeds = torch.cat(embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


def get_weighted_text_embeddings_sdxl_refiner(
    pipe: StableDiffusionXLPipeline
    , prompt: str = ""
    , neg_prompt: str = ""
):
    """
    This function can process long prompts with weights, with no length limitation,
    for the Stable Diffusion XL refiner, which uses only the second text encoder

    Args:
        pipe (StableDiffusionXLPipeline)
        prompt (str)
        neg_prompt (str)
    Returns:
        prompt_embeds (torch.Tensor)
        negative_prompt_embeds (torch.Tensor)
        pooled_prompt_embeds (torch.Tensor)
        negative_pooled_prompt_embeds (torch.Tensor)

    Example:
        from diffusers import StableDiffusionXLImg2ImgPipeline
        refiner_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0"
            , torch_dtype = torch.float16
        ).to("cuda:0")
        (
            prompt_embeds
            , negative_prompt_embeds
            , pooled_prompt_embeds
            , negative_pooled_prompt_embeds
        ) = get_weighted_text_embeddings_sdxl_refiner(
            pipe = refiner_pipe
            , prompt = "a (white) cat"
            , neg_prompt = "blur"
        )
    """
    eos = 49407 # hardcoded clip eos token id; the refiner pipeline may not expose pipe.tokenizer

    # tokenizer 2
    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(
        pipe.tokenizer_2, prompt
    )

    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(
        pipe.tokenizer_2, neg_prompt
    )

    # pad the shorter one for token set 2
    prompt_token_len_2 = len(prompt_tokens_2)
    neg_prompt_token_len_2 = len(neg_prompt_tokens_2)

    if prompt_token_len_2 > neg_prompt_token_len_2:
        # pad the neg_prompt with eos tokens
        neg_prompt_tokens_2 = (
            neg_prompt_tokens_2 +
            [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
        neg_prompt_weights_2 = (
            neg_prompt_weights_2 +
            [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
    else:
        # pad the prompt
        prompt_tokens_2 = (
            prompt_tokens_2
            + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
        prompt_weights_2 = (
            prompt_weights_2
            + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )

    embeds = []
    neg_embeds = []

    prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(
        prompt_tokens_2.copy()
        , prompt_weights_2.copy()
    )

    neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(
        neg_prompt_tokens_2.copy()
        , neg_prompt_weights_2.copy()
    )

    # embedding tokens one by one does not work, embed chunk by chunk instead
    for i in range(len(prompt_token_groups_2)):
        # get positive prompt embeddings with weights
        token_tensor_2 = torch.tensor(
            [prompt_token_groups_2[i]]
            , dtype=torch.long, device=pipe.text_encoder_2.device
        )

        weight_tensor_2 = torch.tensor(
            prompt_weight_groups_2[i]
            , dtype=torch.float16
            , device=pipe.text_encoder_2.device
        )

        # use second text encoder
        prompt_embeds_2 = pipe.text_encoder_2(
            token_tensor_2.to(pipe.text_encoder_2.device)
            , output_hidden_states=True
        )
        prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
        pooled_prompt_embeds = prompt_embeds_2[0]

        prompt_embeds_list = [prompt_embeds_2_hidden_states]
        token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0)

        for j in range(len(weight_tensor_2)):
            if weight_tensor_2[j] != 1.0:
                # optional: map the weight offset through tanh to (-1, 1)
                # ow = weight_tensor_2[j] - 1
                # tanh_weight = (math.exp(ow) / (math.exp(ow) + 1) - 0.5) * 2
                # weight = 1 + tanh_weight

                # add weight method 1:
                # token_embedding[j] = token_embedding[j] * weight
                # token_embedding[j] = (
                #     token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight
                # )

                # add weight method 2:
                token_embedding[j] = (
                    token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor_2[j]
                )

        token_embedding = token_embedding.unsqueeze(0)
        embeds.append(token_embedding)

        # get negative prompt embeddings with weights
        neg_token_tensor_2 = torch.tensor(
            [neg_prompt_token_groups_2[i]]
            , dtype=torch.long, device=pipe.text_encoder_2.device
        )
        neg_weight_tensor_2 = torch.tensor(
            neg_prompt_weight_groups_2[i]
            , dtype=torch.float16
            , device=pipe.text_encoder_2.device
        )

        # use second text encoder
        neg_prompt_embeds_2 = pipe.text_encoder_2(
            neg_token_tensor_2.to(pipe.text_encoder_2.device)
            , output_hidden_states=True
        )
        neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
        negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]

        neg_prompt_embeds_list = [neg_prompt_embeds_2_hidden_states]
        neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0)

        for z in range(len(neg_weight_tensor_2)):
            if neg_weight_tensor_2[z] != 1.0:
                # ow = neg_weight_tensor_2[z] - 1
                # neg_weight = 1 + (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2

                # add weight method 1:
                # neg_token_embedding[z] = neg_token_embedding[z] * neg_weight
                # neg_token_embedding[z] = (
                #     neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight
                # )

                # add weight method 2:
                neg_token_embedding[z] = (
                    neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) *
                    neg_weight_tensor_2[z]
                )

        neg_token_embedding = neg_token_embedding.unsqueeze(0)
        neg_embeds.append(neg_token_embedding)

    prompt_embeds = torch.cat(embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


def get_weighted_text_embeddings_sdxl_2p(
    pipe: StableDiffusionXLPipeline
    , prompt: str = ""
    , prompt_2: str = None
    , neg_prompt: str = ""
    , neg_prompt_2: str = None
):
    """
    This function can process long prompts with weights, with no length limitation,
    for Stable Diffusion XL; it supports a separate prompt set for each text encoder

    Args:
        pipe (StableDiffusionXLPipeline)
        prompt (str)
        prompt_2 (str, optional): prompt for the second text encoder, defaults to prompt
        neg_prompt (str)
        neg_prompt_2 (str, optional): negative prompt for the second text encoder, defaults to neg_prompt
    Returns:
        prompt_embeds (torch.Tensor)
        negative_prompt_embeds (torch.Tensor)
        pooled_prompt_embeds (torch.Tensor)
        negative_pooled_prompt_embeds (torch.Tensor)

    Example:
        from diffusers import StableDiffusionXLPipeline
        text2img_pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0"
            , torch_dtype = torch.float16
        ).to("cuda:0")
        (
            prompt_embeds
            , negative_prompt_embeds
            , pooled_prompt_embeds
            , negative_pooled_prompt_embeds
        ) = get_weighted_text_embeddings_sdxl_2p(
            pipe = text2img_pipe
            , prompt = "a (white) cat"
            , prompt_2 = "fluffy fur, studio lighting"
            , neg_prompt = "blur"
        )
    """
    prompt_2 = prompt_2 or prompt
    neg_prompt_2 = neg_prompt_2 or neg_prompt
    eos = pipe.tokenizer.eos_token_id

    # tokenizer 1
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, prompt
    )

    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, neg_prompt
    )

    # tokenizer 2
    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(
        pipe.tokenizer_2, prompt_2
    )

    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(
        pipe.tokenizer_2, neg_prompt_2
    )

    # pad the shorter one with eos tokens
    prompt_token_len = len(prompt_tokens)
    neg_prompt_token_len = len(neg_prompt_tokens)

    if prompt_token_len > neg_prompt_token_len:
        # pad the neg_prompt with eos tokens
        neg_prompt_tokens = (
            neg_prompt_tokens +
            [eos] * abs(prompt_token_len - neg_prompt_token_len)
        )
        neg_prompt_weights = (
            neg_prompt_weights +
            [1.0] * abs(prompt_token_len - neg_prompt_token_len)
        )
    else:
        # pad the prompt
        prompt_tokens = (
            prompt_tokens
            + [eos] * abs(prompt_token_len - neg_prompt_token_len)
        )
        prompt_weights = (
            prompt_weights
            + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
        )

    # pad the shorter one for token set 2
    prompt_token_len_2 = len(prompt_tokens_2)
    neg_prompt_token_len_2 = len(neg_prompt_tokens_2)

    if prompt_token_len_2 > neg_prompt_token_len_2:
        # pad the neg_prompt with eos tokens
        neg_prompt_tokens_2 = (
            neg_prompt_tokens_2 +
            [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
        neg_prompt_weights_2 = (
            neg_prompt_weights_2 +
            [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
    else:
        # pad the prompt
        prompt_tokens_2 = (
            prompt_tokens_2
            + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
        prompt_weights_2 = (
            prompt_weights_2
            + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )

    # now ensure prompt and prompt_2 have the same length
    prompt_token_len = len(prompt_tokens)
    prompt_token_len_2 = len(prompt_tokens_2)
    if prompt_token_len > prompt_token_len_2:
        prompt_tokens_2 = prompt_tokens_2 + [eos] * abs(prompt_token_len - prompt_token_len_2)
        prompt_weights_2 = prompt_weights_2 + [1.0] * abs(prompt_token_len - prompt_token_len_2)
    else:
        prompt_tokens = prompt_tokens + [eos] * abs(prompt_token_len - prompt_token_len_2)
        prompt_weights = prompt_weights + [1.0] * abs(prompt_token_len - prompt_token_len_2)

    # now ensure neg_prompt and neg_prompt_2 have the same length
    neg_prompt_token_len = len(neg_prompt_tokens)
    neg_prompt_token_len_2 = len(neg_prompt_tokens_2)
    if neg_prompt_token_len > neg_prompt_token_len_2:
        neg_prompt_tokens_2 = neg_prompt_tokens_2 + [eos] * abs(neg_prompt_token_len - neg_prompt_token_len_2)
        neg_prompt_weights_2 = neg_prompt_weights_2 + [1.0] * abs(neg_prompt_token_len - neg_prompt_token_len_2)
    else:
        neg_prompt_tokens = neg_prompt_tokens + [eos] * abs(neg_prompt_token_len - neg_prompt_token_len_2)
        neg_prompt_weights = neg_prompt_weights + [1.0] * abs(neg_prompt_token_len - neg_prompt_token_len_2)

    embeds = []
    neg_embeds = []

    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
        prompt_tokens.copy()
        , prompt_weights.copy()
    )

    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
        neg_prompt_tokens.copy()
        , neg_prompt_weights.copy()
    )

    prompt_token_groups_2, prompt_weight_groups_2 = group_tokens_and_weights(
        prompt_tokens_2.copy()
        , prompt_weights_2.copy()
    )

    neg_prompt_token_groups_2, neg_prompt_weight_groups_2 = group_tokens_and_weights(
        neg_prompt_tokens_2.copy()
        , neg_prompt_weights_2.copy()
    )

    # embedding tokens one by one does not work, embed chunk by chunk instead
    for i in range(len(prompt_token_groups)):
        # get positive prompt embeddings with weights
        token_tensor = torch.tensor(
            [prompt_token_groups[i]]
            , dtype=torch.long, device=pipe.text_encoder.device
        )
        weight_tensor = torch.tensor(
            prompt_weight_groups[i]
            , device=pipe.text_encoder.device
        )

        token_tensor_2 = torch.tensor(
            [prompt_token_groups_2[i]]
            , device=pipe.text_encoder_2.device
        )

        weight_tensor_2 = torch.tensor(
            prompt_weight_groups_2[i]
            , device=pipe.text_encoder_2.device
        )

        # use first text encoder
        prompt_embeds_1 = pipe.text_encoder(
            token_tensor.to(pipe.text_encoder.device)
            , output_hidden_states=True
        )
        prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]

        # use second text encoder
        prompt_embeds_2 = pipe.text_encoder_2(
            token_tensor_2.to(pipe.text_encoder_2.device)
            , output_hidden_states=True
        )
        prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
        pooled_prompt_embeds = prompt_embeds_2[0]

        prompt_embeds_1_hidden_states = prompt_embeds_1_hidden_states.squeeze(0)
        prompt_embeds_2_hidden_states = prompt_embeds_2_hidden_states.squeeze(0)

        # interpolate each weighted token embedding toward the eos embedding
        for j in range(len(weight_tensor)):
            if weight_tensor[j] != 1.0:
                prompt_embeds_1_hidden_states[j] = (
                    prompt_embeds_1_hidden_states[-1] + (
                        prompt_embeds_1_hidden_states[j] - prompt_embeds_1_hidden_states[-1]) * weight_tensor[j]
                )

            if weight_tensor_2[j] != 1.0:
                prompt_embeds_2_hidden_states[j] = (
                    prompt_embeds_2_hidden_states[-1] + (
                        prompt_embeds_2_hidden_states[j] - prompt_embeds_2_hidden_states[-1]) * weight_tensor_2[j]
                )

        prompt_embeds_1_hidden_states = prompt_embeds_1_hidden_states.unsqueeze(0)
        prompt_embeds_2_hidden_states = prompt_embeds_2_hidden_states.unsqueeze(0)

        prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
        token_embedding = torch.cat(prompt_embeds_list, dim=-1)

        embeds.append(token_embedding)

        # get negative prompt embeddings with weights
        neg_token_tensor = torch.tensor(
            [neg_prompt_token_groups[i]]
            , device=pipe.text_encoder.device
        )
        neg_token_tensor_2 = torch.tensor(
            [neg_prompt_token_groups_2[i]]
            , device=pipe.text_encoder_2.device
        )
        neg_weight_tensor = torch.tensor(
            neg_prompt_weight_groups[i]
            , device=pipe.text_encoder.device
        )
        neg_weight_tensor_2 = torch.tensor(
            neg_prompt_weight_groups_2[i]
            , device=pipe.text_encoder_2.device
        )

        # use first text encoder
        neg_prompt_embeds_1 = pipe.text_encoder(
            neg_token_tensor.to(pipe.text_encoder.device)
            , output_hidden_states=True
        )
        neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]

        # use second text encoder
        neg_prompt_embeds_2 = pipe.text_encoder_2(
            neg_token_tensor_2.to(pipe.text_encoder_2.device)
            , output_hidden_states=True
        )
        neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
        negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]

        neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1_hidden_states.squeeze(0)
        neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2_hidden_states.squeeze(0)

        for z in range(len(neg_weight_tensor)):
            if neg_weight_tensor[z] != 1.0:
                neg_prompt_embeds_1_hidden_states[z] = (
                    neg_prompt_embeds_1_hidden_states[-1] + (
                        neg_prompt_embeds_1_hidden_states[z] - neg_prompt_embeds_1_hidden_states[-1]) *
                    neg_weight_tensor[z]
                )

            if neg_weight_tensor_2[z] != 1.0:
                neg_prompt_embeds_2_hidden_states[z] = (
                    neg_prompt_embeds_2_hidden_states[-1] + (
                        neg_prompt_embeds_2_hidden_states[z] - neg_prompt_embeds_2_hidden_states[-1]) *
                    neg_weight_tensor_2[z]
                )

        neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1_hidden_states.unsqueeze(0)
        neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2_hidden_states.unsqueeze(0)

        neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
        neg_token_embedding = torch.cat(neg_prompt_embeds_list, dim=-1)

        neg_embeds.append(neg_token_embedding)

    prompt_embeds = torch.cat(embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


def get_weighted_text_embeddings_sd3(
    pipe: StableDiffusion3Pipeline
    , prompt: str = ""
    , neg_prompt: str = ""
    , pad_last_block=True
    , use_t5_encoder=True
):
    """
    This function can process long prompts with weights, with no length limitation,
    for Stable Diffusion 3

    Args:
        pipe (StableDiffusion3Pipeline)
        prompt (str)
        neg_prompt (str)
    Returns:
        sd3_prompt_embeds (torch.Tensor)
        sd3_neg_prompt_embeds (torch.Tensor)
        pooled_prompt_embeds (torch.Tensor)
        negative_pooled_prompt_embeds (torch.Tensor)
    """
    eos = pipe.tokenizer.eos_token_id

    # tokenizer 1
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, prompt
    )

    neg_prompt_tokens, neg_prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, neg_prompt
    )

    # tokenizer 2
    prompt_tokens_2, prompt_weights_2 = get_prompts_tokens_with_weights(
        pipe.tokenizer_2, prompt
    )

    neg_prompt_tokens_2, neg_prompt_weights_2 = get_prompts_tokens_with_weights(
        pipe.tokenizer_2, neg_prompt
    )

    # tokenizer 3
    prompt_tokens_3, prompt_weights_3, _ = get_prompts_tokens_with_weights_t5(
        pipe.tokenizer_3, prompt
    )

    neg_prompt_tokens_3, neg_prompt_weights_3, _ = get_prompts_tokens_with_weights_t5(
        pipe.tokenizer_3, neg_prompt
    )

    # pad the shorter one with eos tokens
    prompt_token_len = len(prompt_tokens)
    neg_prompt_token_len = len(neg_prompt_tokens)

    if prompt_token_len > neg_prompt_token_len:
        # pad the neg_prompt with eos tokens
        neg_prompt_tokens = (
            neg_prompt_tokens +
            [eos] * abs(prompt_token_len - neg_prompt_token_len)
        )
        neg_prompt_weights = (
            neg_prompt_weights +
            [1.0] * abs(prompt_token_len - neg_prompt_token_len)
        )
    else:
        # pad the prompt
        prompt_tokens = (
            prompt_tokens
            + [eos] * abs(prompt_token_len - neg_prompt_token_len)
        )
        prompt_weights = (
            prompt_weights
            + [1.0] * abs(prompt_token_len - neg_prompt_token_len)
        )

    # pad the shorter one for token set 2
    prompt_token_len_2 = len(prompt_tokens_2)
    neg_prompt_token_len_2 = len(neg_prompt_tokens_2)

    if prompt_token_len_2 > neg_prompt_token_len_2:
        # pad the neg_prompt with eos tokens
        neg_prompt_tokens_2 = (
            neg_prompt_tokens_2 +
            [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
        neg_prompt_weights_2 = (
            neg_prompt_weights_2 +
            [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
    else:
        # pad the prompt
        prompt_tokens_2 = (
            prompt_tokens_2
            + [eos] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )
        prompt_weights_2 = (
            prompt_weights_2
            + [1.0] * abs(prompt_token_len_2 - neg_prompt_token_len_2)
        )

    embeds = []
    neg_embeds = []

    prompt_token_groups, prompt_weight_groups = group_tokens_and_weights(
        prompt_tokens.copy()
        , prompt_weights.copy()
        , pad_last_block=pad_last_block
    )

    neg_prompt_token_groups, neg_prompt_weight_groups = group_tokens_and_weights(
        neg_prompt_tokens.copy()
        , neg_prompt_weights.copy()
        , pad_last_block=pad_last_block
    )

    prompt_token_groups_2, _prompt_weight_groups_2 = group_tokens_and_weights(
        prompt_tokens_2.copy()
        , prompt_weights_2.copy()
        , pad_last_block=pad_last_block
    )

    neg_prompt_token_groups_2, _neg_prompt_weight_groups_2 = group_tokens_and_weights(
        neg_prompt_tokens_2.copy()
        , neg_prompt_weights_2.copy()
        , pad_last_block=pad_last_block
    )

    # embedding tokens one by one does not work, embed chunk by chunk instead
    for i in range(len(prompt_token_groups)):
        # get positive prompt embeddings with weights
        token_tensor = torch.tensor(
            [prompt_token_groups[i]]
            , dtype=torch.long, device=pipe.text_encoder.device
        )
        weight_tensor = torch.tensor(
            prompt_weight_groups[i]
            , dtype=torch.float16
            , device=pipe.text_encoder.device
        )

        token_tensor_2 = torch.tensor(
            [prompt_token_groups_2[i]]
            , dtype=torch.long, device=pipe.text_encoder_2.device
        )

        # use first text encoder
        prompt_embeds_1 = pipe.text_encoder(
            token_tensor.to(pipe.text_encoder.device)
            , output_hidden_states=True
        )
        prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
        pooled_prompt_embeds_1 = prompt_embeds_1[0]

        # use second text encoder
        prompt_embeds_2 = pipe.text_encoder_2(
            token_tensor_2.to(pipe.text_encoder_2.device)
            , output_hidden_states=True
        )
        prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
        pooled_prompt_embeds_2 = prompt_embeds_2[0]

        prompt_embeds_list = [prompt_embeds_1_hidden_states, prompt_embeds_2_hidden_states]
        token_embedding = torch.concat(prompt_embeds_list, dim=-1).squeeze(0).to(pipe.text_encoder.device)

        for j in range(len(weight_tensor)):
            if weight_tensor[j] != 1.0:
                # optional: map the weight offset through tanh to (-1, 1)
                # ow = weight_tensor[j] - 1
                # tanh_weight = (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2
                # weight = 1 + tanh_weight

                # add weight method 1:
                # token_embedding[j] = token_embedding[j] * weight
                # token_embedding[j] = (
                #     token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight
                # )

                # add weight method 2:
                # token_embedding[j] = (
                #     token_embedding[-1] + (token_embedding[j] - token_embedding[-1]) * weight_tensor[j]
                # )

                # add weight method 3:
                token_embedding[j] = token_embedding[j] * weight_tensor[j]

        token_embedding = token_embedding.unsqueeze(0)
        embeds.append(token_embedding)

        # get negative prompt embeddings with weights
        neg_token_tensor = torch.tensor(
            [neg_prompt_token_groups[i]]
            , dtype=torch.long, device=pipe.text_encoder.device
        )
        neg_token_tensor_2 = torch.tensor(
            [neg_prompt_token_groups_2[i]]
            , dtype=torch.long, device=pipe.text_encoder_2.device
        )
        neg_weight_tensor = torch.tensor(
            neg_prompt_weight_groups[i]
            , dtype=torch.float16
            , device=pipe.text_encoder.device
        )

        # use first text encoder
        neg_prompt_embeds_1 = pipe.text_encoder(
            neg_token_tensor.to(pipe.text_encoder.device)
            , output_hidden_states=True
        )
        neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
        negative_pooled_prompt_embeds_1 = neg_prompt_embeds_1[0]

        # use second text encoder
        neg_prompt_embeds_2 = pipe.text_encoder_2(
            neg_token_tensor_2.to(pipe.text_encoder_2.device)
            , output_hidden_states=True
        )
        neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
        negative_pooled_prompt_embeds_2 = neg_prompt_embeds_2[0]

        neg_prompt_embeds_list = [neg_prompt_embeds_1_hidden_states, neg_prompt_embeds_2_hidden_states]
        neg_token_embedding = torch.concat(neg_prompt_embeds_list, dim=-1).squeeze(0).to(pipe.text_encoder.device)

        for z in range(len(neg_weight_tensor)):
            if neg_weight_tensor[z] != 1.0:
                # ow = neg_weight_tensor[z] - 1
                # neg_weight = 1 + (math.exp(ow)/(math.exp(ow) + 1) - 0.5) * 2

                # add weight method 1:
                # neg_token_embedding[z] = neg_token_embedding[z] * neg_weight
                # neg_token_embedding[z] = (
                #     neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight
                # )

                # add weight method 2:
                # neg_token_embedding[z] = (
                #     neg_token_embedding[-1] + (neg_token_embedding[z] - neg_token_embedding[-1]) * neg_weight_tensor[z]
                # )

                # add weight method 3:
                neg_token_embedding[z] = neg_token_embedding[z] * neg_weight_tensor[z]

        neg_token_embedding = neg_token_embedding.unsqueeze(0)
        neg_embeds.append(neg_token_embedding)

    prompt_embeds = torch.cat(embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)

    pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
    negative_pooled_prompt_embeds = torch.cat(
        [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1
    )

    if use_t5_encoder and pipe.text_encoder_3:
        # ----------------- generate positive t5 embeddings --------------------
        prompt_tokens_3 = torch.tensor([prompt_tokens_3], dtype=torch.long)

        t5_prompt_embeds = pipe.text_encoder_3(prompt_tokens_3.to(pipe.text_encoder_3.device))[0].squeeze(0)
        t5_prompt_embeds = t5_prompt_embeds.to(device=pipe.text_encoder_3.device)

        # apply weights to the t5 prompt embeddings
        for z in range(len(prompt_weights_3)):
            if prompt_weights_3[z] != 1.0:
                t5_prompt_embeds[z] = t5_prompt_embeds[z] * prompt_weights_3[z]
        t5_prompt_embeds = t5_prompt_embeds.unsqueeze(0)
    else:
        # no t5 encoder available; use a zero-filled placeholder on the clip embedding device
        t5_prompt_embeds = torch.zeros(1, 4096, dtype=prompt_embeds.dtype).unsqueeze(0)
        t5_prompt_embeds = t5_prompt_embeds.to(device=prompt_embeds.device)

    # merge with the clip embedding 1 and clip embedding 2
    clip_prompt_embeds = torch.nn.functional.pad(
        prompt_embeds, (0, t5_prompt_embeds.shape[-1] - prompt_embeds.shape[-1])
    )
    sd3_prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embeds], dim=-2)

    if use_t5_encoder and pipe.text_encoder_3:
        # ---------------------- get neg t5 embeddings -------------------------
        neg_prompt_tokens_3 = torch.tensor([neg_prompt_tokens_3], dtype=torch.long)

        t5_neg_prompt_embeds = pipe.text_encoder_3(neg_prompt_tokens_3.to(pipe.text_encoder_3.device))[0].squeeze(0)
        t5_neg_prompt_embeds = t5_neg_prompt_embeds.to(device=pipe.text_encoder_3.device)

        # apply weights to the negative t5 embeddings
        for z in range(len(neg_prompt_weights_3)):
            if neg_prompt_weights_3[z] != 1.0:
                t5_neg_prompt_embeds[z] = t5_neg_prompt_embeds[z] * neg_prompt_weights_3[z]
        t5_neg_prompt_embeds = t5_neg_prompt_embeds.unsqueeze(0)
    else:
        # no t5 encoder available; use a zero-filled placeholder on the clip embedding device
        t5_neg_prompt_embeds = torch.zeros(1, 4096, dtype=prompt_embeds.dtype).unsqueeze(0)
        t5_neg_prompt_embeds = t5_neg_prompt_embeds.to(device=negative_prompt_embeds.device)

    clip_neg_prompt_embeds = torch.nn.functional.pad(
        negative_prompt_embeds, (0, t5_neg_prompt_embeds.shape[-1] - negative_prompt_embeds.shape[-1])
    )
    sd3_neg_prompt_embeds = torch.cat([clip_neg_prompt_embeds, t5_neg_prompt_embeds], dim=-2)

    # pad the shorter of the positive/negative embeddings along the sequence axis.
    # F.pad for a 3d tensor takes (last_dim_left, last_dim_right, seq_top, seq_bottom, batch_front, batch_back),
    # so padding at the end of the sequence dimension (dim=1) uses (0, 0, 0, size_diff, 0, 0)
    size_diff = sd3_neg_prompt_embeds.size(1) - sd3_prompt_embeds.size(1)
    if size_diff > 0:
        padding = (0, 0, 0, abs(size_diff), 0, 0)
        sd3_prompt_embeds = F.pad(sd3_prompt_embeds, padding)
    elif size_diff < 0:
        padding = (0, 0, 0, abs(size_diff), 0, 0)
        sd3_neg_prompt_embeds = F.pad(sd3_neg_prompt_embeds, padding)

    return sd3_prompt_embeds, sd3_neg_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds


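# Usage sketch for the SD3 helper above (illustrative; the repo id
# "stabilityai/stable-diffusion-3-medium-diffusers" is an assumption):
#
#   from diffusers import StableDiffusion3Pipeline
#   pipe = StableDiffusion3Pipeline.from_pretrained(
#       "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
#   ).to("cuda:0")
#   (prompt_embeds, neg_prompt_embeds,
#    pooled_prompt_embeds, neg_pooled_prompt_embeds) = get_weighted_text_embeddings_sd3(
#       pipe, prompt="a (white:1.3) cat", neg_prompt="blur"
#   )
#   image = pipe(
#       prompt_embeds=prompt_embeds,
#       negative_prompt_embeds=neg_prompt_embeds,
#       pooled_prompt_embeds=pooled_prompt_embeds,
#       negative_pooled_prompt_embeds=neg_pooled_prompt_embeds,
#   ).images[0]

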
def get_weighted_text_embeddings_flux1(
    pipe: FluxPipeline
    , prompt: str = ""
    , prompt2: str = None
    , device=None
):
    """
    This function can process long prompts with weights for the Flux.1 model

    Args:
        pipe (FluxPipeline)
        prompt (str)
        prompt2 (str, optional): prompt for the T5 encoder, defaults to prompt
        device (torch.device, optional): device to run the embeddings on
    Returns:
        t5_prompt_embeds (torch.Tensor): per-token embeddings from the T5 encoder
        prompt_embeds (torch.Tensor): pooled clip embeddings, averaged over all 77-token chunks
    """
    prompt2 = prompt if prompt2 is None else prompt2
    if device is None:
        device = pipe.text_encoder.device

    # tokenizer 1 - openai/clip-vit-large-patch14
    prompt_tokens, prompt_weights = get_prompts_tokens_with_weights(
        pipe.tokenizer, prompt
    )

    # tokenizer 2 - google/t5-v1_1-xxl
    prompt_tokens_2, prompt_weights_2, _ = get_prompts_tokens_with_weights_t5(
        pipe.tokenizer_2, prompt2
    )

    prompt_token_groups, _prompt_weight_groups = group_tokens_and_weights(
        prompt_tokens.copy()
        , prompt_weights.copy()
        , pad_last_block=True
    )

    # # alternative: flux1 uses only the pooled embeddings of text_encoder 1,
    # # taken from the first 77-token chunk
    # token_tensor = torch.tensor(
    #     [prompt_token_groups[0]]
    #     , dtype = torch.long, device = device
    # )
    # # use first text encoder
    # prompt_embeds_1 = pipe.text_encoder(
    #     token_tensor.to(device)
    #     , output_hidden_states = False
    # )
    # pooled_prompt_embeds_1 = prompt_embeds_1.pooler_output
    # prompt_embeds = pooled_prompt_embeds_1.to(dtype = pipe.text_encoder.dtype, device = device)

    # use average pooled embeddings across all chunks instead
    pool_embeds_list = []
    for token_group in prompt_token_groups:
        token_tensor = torch.tensor(
            [token_group]
            , dtype=torch.long
            , device=device
        )
        prompt_embeds_1 = pipe.text_encoder(
            token_tensor.to(device)
            , output_hidden_states=False
        )
        pooled_prompt_embeds = prompt_embeds_1.pooler_output.squeeze(0)
        pool_embeds_list.append(pooled_prompt_embeds)

    prompt_embeds = torch.stack(pool_embeds_list, dim=0)

    # average the pooled embeddings over the chunks
    prompt_embeds = prompt_embeds.mean(dim=0, keepdim=True)
    prompt_embeds = prompt_embeds.to(dtype=pipe.text_encoder.dtype, device=device)

    # generate positive t5 embeddings
    prompt_tokens_2 = torch.tensor([prompt_tokens_2], dtype=torch.long)

    t5_prompt_embeds = pipe.text_encoder_2(prompt_tokens_2.to(device))[0].squeeze(0)
    t5_prompt_embeds = t5_prompt_embeds.to(device=device)

    # apply weights to the t5 prompt embeddings
    for z in range(len(prompt_weights_2)):
        if prompt_weights_2[z] != 1.0:
            t5_prompt_embeds[z] = t5_prompt_embeds[z] * prompt_weights_2[z]
    t5_prompt_embeds = t5_prompt_embeds.unsqueeze(0)

    t5_prompt_embeds = t5_prompt_embeds.to(dtype=pipe.text_encoder_2.dtype, device=device)

    return t5_prompt_embeds, prompt_embeds


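# Usage sketch for the Flux helper above (illustrative; the repo id
# "black-forest-labs/FLUX.1-dev" is an assumption). Note the return order:
# the first tensor is the T5 sequence embedding (FluxPipeline's prompt_embeds),
# the second is the averaged clip pooled embedding (pooled_prompt_embeds):
#
#   from diffusers import FluxPipeline
#   pipe = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
#   ).to("cuda:0")
#   t5_embeds, pooled_embeds = get_weighted_text_embeddings_flux1(
#       pipe, prompt="a (red:1.5) cat"
#   )
#   image = pipe(
#       prompt_embeds=t5_embeds, pooled_prompt_embeds=pooled_embeds
#   ).images[0]

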
def get_weighted_text_embeddings_chroma(
    pipe: ChromaPipeline,
    prompt: str = "",
    neg_prompt: str = "",
    device=None
):
    """
    This function can process long prompts with weights for the Chroma model

    Args:
        pipe (ChromaPipeline)
        prompt (str)
        neg_prompt (str)
        device (torch.device, optional): Device to run the embeddings on.
    Returns:
        prompt_embeds (torch.Tensor)
        prompt_attention_mask (torch.Tensor)
        neg_prompt_embeds (torch.Tensor)
        neg_prompt_attention_mask (torch.Tensor)
    """
    if device is None:
        device = pipe.text_encoder.device

    dtype = pipe.text_encoder.dtype

    prompt_tokens, prompt_weights, prompt_masks = get_prompts_tokens_with_weights_t5(
        pipe.tokenizer, prompt, add_special_tokens=False
    )

    neg_prompt_tokens, neg_prompt_weights, neg_prompt_masks = get_prompts_tokens_with_weights_t5(
        pipe.tokenizer, neg_prompt, add_special_tokens=False
    )

    prompt_tokens, prompt_weights, prompt_masks = pad_prompt_tokens_to_length_chroma(
        pipe,
        prompt_tokens,
        prompt_weights,
        prompt_masks
    )

    prompt_embeds, prompt_masks = get_weighted_prompt_embeds_with_attention_mask_chroma(
        pipe,
        prompt_tokens,
        prompt_weights,
        prompt_masks,
        device=device,
        dtype=dtype
    )

    neg_prompt_tokens, neg_prompt_weights, neg_prompt_masks = pad_prompt_tokens_to_length_chroma(
        pipe,
        neg_prompt_tokens,
        neg_prompt_weights,
        neg_prompt_masks
    )

    neg_prompt_embeds, neg_prompt_masks = get_weighted_prompt_embeds_with_attention_mask_chroma(
        pipe,
        neg_prompt_tokens,
        neg_prompt_weights,
        neg_prompt_masks,
        device=device,
        dtype=dtype
    )

    return prompt_embeds, prompt_masks, neg_prompt_embeds, neg_prompt_masks


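# Usage sketch for the Chroma helper above (illustrative only; exact Chroma
# packaging and pipeline arguments may differ):
#
#   embeds, masks, neg_embeds, neg_masks = get_weighted_text_embeddings_chroma(
#       pipe, prompt="a (red:1.5) cat", neg_prompt="blur"
#   )
#   # embeds/neg_embeds are (1, seq_len, dim) weighted T5 embeddings;
#   # masks/neg_masks are the matching attention masks for the Chroma transformer

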
def get_weighted_prompt_embeds_with_attention_mask_chroma(
    pipe: ChromaPipeline,
    tokens,
    weights,
    masks,
    device,
    dtype
):
    """
    Encode a single token sequence with the T5 text encoder and apply per-token weights
    """
    prompt_tokens = torch.tensor([tokens], dtype=torch.long, device=device)
    prompt_masks = torch.tensor([masks], dtype=torch.long, device=device)
    prompt_embeds = pipe.text_encoder(prompt_tokens, output_hidden_states=False, attention_mask=prompt_masks)[0].squeeze(0)
    # scale each token embedding by its prompt weight
    for z in range(len(weights)):
        if weights[z] != 1.0:
            prompt_embeds[z] = prompt_embeds[z] * weights[z]
    prompt_embeds = prompt_embeds.unsqueeze(0).to(dtype=dtype, device=device)
    return prompt_embeds, prompt_masks


def pad_prompt_tokens_to_length_chroma(pipe, input_tokens, input_weights, input_masks, min_length=5, add_eos_token=True):
    """
    Implementation of Chroma's padding for prompt tokens.
    Pads the token list to at least min_length, masking the padding tokens
    while keeping at least one padding token and one eos token unmasked.

    https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
    """
    output_tokens = input_tokens.copy()
    output_weights = input_weights.copy()
    output_masks = input_masks.copy()

    pad_token_id = pipe.tokenizer.pad_token_id
    eos_token_id = pipe.tokenizer.eos_token_id

    # add one padding token by default, unless the input already contains padding
    pad_length = 1

    for j, token in enumerate(output_tokens):
        if token == pad_token_id:
            output_masks[j] = 0
            pad_length = 0

    current_length = len(output_tokens)

    if current_length < min_length:
        pad_length = min_length - current_length

    if pad_length > 0:
        output_tokens += [pad_token_id] * pad_length
        output_weights += [1.0] * pad_length
        output_masks += [0] * pad_length

    # ensure the last token stays unmasked, per Chroma's padding-token masking scheme
    output_masks[-1] = 1

    if add_eos_token and output_tokens[-1] != eos_token_id:
        output_tokens += [eos_token_id]
        output_weights += [1.0]
        output_masks += [1]

    return output_tokens, output_weights, output_masks
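

# Behavior sketch for the padding helper above (a minimal illustration with made-up
# token ids; assumes pad_token_id=0 and eos_token_id=1 as in the T5 vocabulary):
#
#   tokens, weights, masks = pad_prompt_tokens_to_length_chroma(
#       pipe, [71, 1712, 10], [1.0, 1.5, 1.0], [1, 1, 1]
#   )
#   # tokens  -> [71, 1712, 10, 0, 0, 1]   two pad tokens reach min_length=5, then eos
#   # masks   -> [1, 1, 1, 0, 1, 1]        last pad token stays unmasked, eos unmasked
#   # weights -> [1.0, 1.5, 1.0, 1.0, 1.0, 1.0]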