# mirror of https://github.com/vladmandic/sdnext.git
# synced 2026-01-27 15:02:48 +03:00
# repo: https://huggingface.co/gokaygokay/Flux-Prompt-Enhance
import time
import random
import threading
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr
from modules import shared, scripts_manager, devices, processing


# Hugging Face repo providing the seq2seq prompt-enhancement model
repo_id = "gokaygokay/Flux-Prompt-Enhance"
# number of alternative prompts generated per enhance() call
num_return_sequences = 5
# serializes lazy tokenizer/model loading across threads
load_lock = threading.Lock()
|
|
|
|
|
|
class Script(scripts_manager.Script):
    """Prompt-enhance extension: expands a short prompt into richer variants
    using the Flux-Prompt-Enhance seq2seq model, shown in a selectable table."""

    # default table contents (one empty row) before any generation
    prompts = [['']]
    # tokenizer/model are loaded lazily on first use; see load()
    tokenizer: AutoTokenizer = None
    model: AutoModelForSeq2SeqLM = None
    # task prefix expected by the model
    prefix: str = "enhance prompt: "
    # gradio components, created in ui(); prompt is discovered in after_component()
    button: gr.Button = None
    auto_apply: gr.Checkbox = None
    max_length: gr.Slider = None
    temperature: gr.Slider = None
    repetition_penalty: gr.Slider = None
    table: gr.DataFrame = None
    prompt: gr.Textbox = None

    def title(self):
        """Display name of the script in the UI."""
        return 'Flux Prompt enhance (Legacy)'

    def show(self, is_img2img):
        # available in both txt2img and img2img tabs
        return True

    def load(self):
        """Lazily load tokenizer and model; thread-safe and idempotent.

        The model is kept on CPU until enhance() temporarily moves it to the
        active device.
        """
        with load_lock:
            if self.tokenizer is None:
                # fix: use the repo_id constant instead of a duplicated hard-coded string
                self.tokenizer = AutoTokenizer.from_pretrained(repo_id, cache_dir=shared.opts.hfcache_dir)
            if self.model is None:
                shared.log.info(f'Prompt enhance: model="{repo_id}"')
                self.model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, cache_dir=shared.opts.hfcache_dir).to(device=devices.cpu, dtype=devices.dtype)

    def enhance(self, prompt, auto_apply: bool = False, temperature: float = 0.7, repetition_penalty: float = 1.2, max_length: int = 128):
        """Generate `num_return_sequences` enhanced variants of `prompt`.

        Returns a list of single-item lists (rows for the gr.DataFrame);
        returns [['']] on generation failure.
        """
        self.load()
        t0 = time.time()
        input_text = self.prefix + prompt
        input_ids = self.tokenizer(input_text, return_tensors="pt").input_ids.to(devices.device)
        # move model to the active device only for the duration of generation
        self.model = self.model.to(devices.device)
        kwargs = {
            'max_length': int(max_length),
            'num_return_sequences': int(num_return_sequences),
            'do_sample': True,
            'temperature': float(temperature),
            'repetition_penalty': float(repetition_penalty),
        }
        try:
            outputs = self.model.generate(input_ids, **kwargs)
        except Exception as e:
            shared.log.error(f'Prompt enhance: error="{e}"')
            return [['']]
        finally:
            # fix: always offload back to CPU — the original left the model on
            # the device when generate() raised, pinning accelerator memory
            self.model = self.model.to(devices.cpu)
        prompts = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        prompts = [[p] for p in prompts]
        t1 = time.time()
        shared.log.info(f'Prompt enhance: temperature={temperature} repetition={repetition_penalty} length={max_length} sequences={num_return_sequences} apply={auto_apply} time={t1-t0:.2f}s')
        return prompts

    def select(self, cell: gr.SelectData, _table):
        """Return the clicked table cell's text so it can replace the main prompt."""
        prompt = cell.value if hasattr(cell, 'value') else cell
        shared.log.info(f'Prompt enhance: prompt="{prompt}"')
        return prompt

    def ui(self, _is_img2img):
        """Build the script UI and wire events; returns the run() inputs."""
        with gr.Row():
            self.button = gr.Button(value='Enhance prompt')
            self.auto_apply = gr.Checkbox(label='Auto apply', value=False)
        with gr.Row():
            self.max_length = gr.Slider(label='Length', minimum=64, maximum=512, step=1, value=128)
            self.temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=2.0, step=0.05, value=0.7)
            self.repetition_penalty = gr.Slider(label='Penalty', minimum=0.1, maximum=2.0, step=0.05, value=1.2)
        with gr.Row():
            self.table = gr.DataFrame(self.prompts, label='', show_label=False, interactive=False, wrap=True, datatype="str", col_count=1, headers=['Prompts'])

        # events can only be wired once after_component() has found the prompt textbox
        if self.prompt is not None:
            self.button.click(fn=self.enhance, inputs=[self.prompt, self.auto_apply, self.temperature, self.repetition_penalty, self.max_length], outputs=[self.table])
            self.table.select(fn=self.select, inputs=[self.table], outputs=[self.prompt])
        return [self.auto_apply, self.temperature, self.repetition_penalty, self.max_length]

    def run(self, p: processing.StableDiffusionProcessing, auto_apply, temperature, repetition_penalty, max_length): # pylint: disable=arguments-differ
        """If auto-apply is enabled, replace p.prompt with a randomly chosen enhanced variant."""
        if auto_apply:
            # bake styles into the prompt first so the enhancer sees the final text
            p.prompt = shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)
            p.negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)
            shared.prompt_styles.apply_styles_to_extra(p)
            p.styles = []
            shared.log.debug(f'Prompt enhance: source="{p.prompt}"')
            prompts = self.enhance(p.prompt, auto_apply, temperature, repetition_penalty, max_length)
            p.prompt = random.choice(prompts)[0]
            shared.log.debug(f'Prompt enhance: prompt="{p.prompt}"')

    def after_component(self, component, **kwargs): # searching for actual ui prompt components
        """Capture the main prompt textbox so ui() can wire events to it."""
        if getattr(component, 'elem_id', '') in ['txt2img_prompt', 'img2img_prompt', 'control_prompt', 'video_prompt']:
            self.prompt = component
            self.prompt.use_original = True