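"""Prompt handling helpers for diffusers pipelines: batch normalization, per-model fixups, and embed/fallback argument assembly."""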
import os
import torch
from modules import shared, errors, timer, prompt_parser_diffusers

debug_enabled = os.environ.get('SD_PROMPT_DEBUG', None) is not None
debug_log = shared.log.trace if debug_enabled else lambda *args, **kwargs: None


def fix_prompt_batch(p, prompts, negative_prompts, prompts_2, negative_prompts_2):
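    """Normalize all four prompt variants to lists matching the effective batch size.

    Scalars are promoted to single-element lists and short lists are padded by
    repeating their last element; a `keep_prompts` attribute opts out entirely.
    """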
    if hasattr(p, 'keep_prompts'):
        return prompts, negative_prompts, prompts_2, negative_prompts_2

    if type(prompts) is str:
        prompts = [prompts]
    if type(negative_prompts) is str:
        negative_prompts = [negative_prompts]
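    # when img2img runs with multiple init images, pad prompts to match the image count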
    if hasattr(p, 'init_images') and p.init_images is not None and len(p.init_images) > 1:
        while len(prompts) < len(p.init_images):
            prompts.append(prompts[-1])
        while len(negative_prompts) < len(p.init_images):
            negative_prompts.append(negative_prompts[-1])
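    # pad both prompt lists up to the requested batch size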
    while len(prompts) < p.batch_size:
        prompts.append(prompts[-1])
    while len(negative_prompts) < p.batch_size:
        negative_prompts.append(negative_prompts[-1])
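    # keep positive and negative lists the same length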
    while len(negative_prompts) < len(prompts):
        negative_prompts.append(negative_prompts[-1])
    while len(prompts) < len(negative_prompts):
        prompts.append(prompts[-1])
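    # secondary prompts follow the length of their primary counterparts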
    if type(prompts_2) is str:
        prompts_2 = [prompts_2]
    if type(prompts_2) is list:
        while len(prompts_2) < len(prompts):
            prompts_2.append(prompts_2[-1])
    if type(negative_prompts_2) is str:
        negative_prompts_2 = [negative_prompts_2]
    if type(negative_prompts_2) is list:
        while len(negative_prompts_2) < len(prompts_2):
            negative_prompts_2.append(negative_prompts_2[-1])
    return prompts, negative_prompts, prompts_2, negative_prompts_2


def fix_prompt_model(cls, prompts, negative_prompts, prompts_2, negative_prompts_2):
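    """Apply per-pipeline prompt quirks keyed off the pipeline class name."""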
    if 'OmniGen' in cls:
        prompts = [p.replace('|image|', '<img><|image_1|></img>') for p in prompts]
    if 'PixArtSigmaPipeline' in cls: # pixart-sigma pipeline throws on list-of-list negative prompts, so pass a single string
        negative_prompts = negative_prompts[0]
    return prompts, negative_prompts, prompts_2, negative_prompts_2


def set_fallback_prompt(args: dict, possible: list[str], prompts, negative_prompts, prompts_2, negative_prompts_2) -> dict:
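    """Copy plain-text prompts into pipeline args for any supported argument not already set."""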
    if ('prompt' in possible) and ('prompt' not in args) and (prompts is not None) and len(prompts) > 0:
        debug_log(f'Prompt fallback: prompt={prompts}')
        args['prompt'] = prompts
    if ('negative_prompt' in possible) and ('negative_prompt' not in args) and (negative_prompts is not None) and len(negative_prompts) > 0:
        debug_log(f'Prompt fallback: negative_prompt={negative_prompts}')
        args['negative_prompt'] = negative_prompts
    if ('prompt_2' in possible) and ('prompt_2' not in args) and (prompts_2 is not None) and len(prompts_2) > 0:
        debug_log(f'Prompt fallback: prompt_2={prompts_2}')
        args['prompt_2'] = prompts_2
    if ('negative_prompt_2' in possible) and ('negative_prompt_2' not in args) and (negative_prompts_2 is not None) and len(negative_prompts_2) > 0:
        debug_log(f'Prompt fallback: negative_prompt_2={negative_prompts_2}')
        args['negative_prompt_2'] = negative_prompts_2
    return args


def set_prompt(p,
               args: dict,
               possible: list[str],
               cls: str,
               prompt_attention: str,
               steps: int,
               clip_skip: int,
               prompts: list[str],
               negative_prompts: list[str],
               prompts_2: list[str],
               negative_prompts_2: list[str],
               ) -> dict:
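    """Encode prompts into embeddings for pipelines that support it, otherwise fall back to raw text.

    Returns the effective prompt_attention mode and the updated pipeline args.
    """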
    prompt_attention = prompt_attention or shared.opts.prompt_attention
    if (prompt_attention != 'fixed') and ('Onnx' not in cls) and ('prompt' not in p.task_args) and (
        ('StableDiffusion' in cls) or
        ('StableCascade' in cls) or
        ('Flux' in cls and 'Flux2' not in cls) or
        ('Chroma' in cls) or
        ('HiDreamImagePipeline' in cls)
    ):
        jobid = shared.state.begin('TE Encode')
        try:
            prompt_parser_diffusers.embedder = prompt_parser_diffusers.PromptEmbedder(prompts, negative_prompts, steps, clip_skip, p)
        except Exception as e:
            prompt_parser_diffusers.embedder = None
            shared.log.error(f'Prompt parser encode: {e}')
            if debug_enabled:
                errors.display(e, 'Prompt parser encode')
        timer.process.record('prompt', reset=False)
        shared.state.end(jobid)
    else:
        prompt_parser_diffusers.embedder = None
        prompt_attention = 'fixed'
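    # normalize batch sizes and apply per-model prompt quirks before building pipeline args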
    prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompt_batch(p, prompts, negative_prompts, prompts_2, negative_prompts_2)
    prompts, negative_prompts, prompts_2, negative_prompts_2 = fix_prompt_model(cls, prompts, negative_prompts, prompts_2, negative_prompts_2)
    if prompt_parser_diffusers.embedder is not None:
        if 'prompt' in possible:
            debug_log(f'Prompt set embeds: positive={prompts}')
            prompt_embeds = prompt_parser_diffusers.embedder('prompt_embeds')
            prompt_pooled_embeds = prompt_parser_diffusers.embedder('positive_pooleds')
            prompt_attention_masks = prompt_parser_diffusers.embedder('prompt_attention_masks')
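            # embeds may be missing or left on the meta device; either way fall back to raw text prompts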
            if prompt_embeds is None:
                shared.log.warning('Prompt parser encode: empty prompt embeds')
                prompt_parser_diffusers.embedder = None
                args = set_fallback_prompt(args, possible, prompts=prompts, negative_prompts=None, prompts_2=None, negative_prompts_2=None)
                prompt_attention = 'fixed'
            elif prompt_embeds.device == torch.device('meta'):
                shared.log.warning('Prompt parser encode: embeds on meta device')
                prompt_parser_diffusers.embedder = None
                args = set_fallback_prompt(args, possible, prompts=prompts, negative_prompts=None, prompts_2=None, negative_prompts_2=None)
                prompt_attention = 'fixed'
            else:
                if 'prompt_embeds' in possible:
                    args['prompt_embeds'] = prompt_embeds
                else:
                    args = set_fallback_prompt(args, possible, prompts=prompts, negative_prompts=None, prompts_2=None, negative_prompts_2=None)
                if 'pooled_prompt_embeds' in possible:
                    args['pooled_prompt_embeds'] = prompt_pooled_embeds
                if 'StableCascade' in cls:
                    args['prompt_embeds_pooled'] = prompt_pooled_embeds.unsqueeze(0)
                if 'HiDreamImage' in cls:
                    args['prompt_embeds_t5'] = prompt_embeds[0]
                    args['prompt_embeds_llama3'] = prompt_embeds[1]
                if 'prompt_attention_mask' in possible:
                    args['prompt_attention_mask'] = prompt_attention_masks
        if 'negative_prompt' in possible:
            debug_log(f'Prompt set embeds: negative={negative_prompts}')
            negative_embeds = prompt_parser_diffusers.embedder('negative_prompt_embeds')
            negative_pooled_embeds = prompt_parser_diffusers.embedder('negative_pooleds')
            negative_attention_masks = prompt_parser_diffusers.embedder('negative_prompt_attention_masks')
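            # same recovery path as the positive branch: missing or meta-device embeds fall back to text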
            if negative_embeds is None:
                shared.log.warning('Prompt parser encode: empty negative prompt embeds')
                prompt_parser_diffusers.embedder = None
                args = set_fallback_prompt(args, possible, prompts=None, negative_prompts=negative_prompts, prompts_2=None, negative_prompts_2=None)
                prompt_attention = 'fixed'
            elif negative_embeds.device == torch.device('meta'):
                shared.log.warning('Prompt parser encode: negative embeds on meta device')
                prompt_parser_diffusers.embedder = None
                args = set_fallback_prompt(args, possible, prompts=None, negative_prompts=negative_prompts, prompts_2=None, negative_prompts_2=None)
                prompt_attention = 'fixed'
            else:
                if 'negative_prompt_embeds' in possible:
                    args['negative_prompt_embeds'] = negative_embeds
                else:
                    args = set_fallback_prompt(args, possible, prompts=None, negative_prompts=negative_prompts, prompts_2=None, negative_prompts_2=None)
                if 'negative_pooled_prompt_embeds' in possible:
                    args['negative_pooled_prompt_embeds'] = negative_pooled_embeds
                if 'StableCascade' in cls:
                    args['negative_prompt_embeds_pooled'] = negative_pooled_embeds.unsqueeze(0)
                if 'HiDreamImage' in cls:
                    args['negative_prompt_embeds_t5'] = negative_embeds[0]
                    args['negative_prompt_embeds_llama3'] = negative_embeds[1]
                if 'negative_prompt_attention_mask' in possible:
                    args['negative_prompt_attention_mask'] = negative_attention_masks
    else:
        debug_log('Prompt fallback: no embedder')
        args = set_fallback_prompt(args, possible, prompts=prompts, negative_prompts=negative_prompts, prompts_2=None, negative_prompts_2=None)
        prompt_attention = 'fixed'
    if 'prompt_embeds' not in args and 'negative_prompt_embeds' not in args: # pass secondary prompts as-is
        args = set_fallback_prompt(args, possible, prompts=None, negative_prompts=None, prompts_2=prompts_2, negative_prompts_2=negative_prompts_2)

    if (prompt_parser_diffusers.embedder is not None) and (not prompt_parser_diffusers.embedder.scheduled_prompt):
        prompt_parser_diffusers.embedder = None # prompt is not scheduled, so we don't need the embedder anymore

    return prompt_attention, args