import gradio as gr
import torch
import numpy as np
import diffusers
from modules import scripts_manager, processing, shared, devices
# Model types (compared against shared.sd_model_type) this script will run on.
supported_model_list = ['sdxl']
# Module-level state shared between Script.run / Script.after / Script.reset:
#   handler               - sa_handler.Handler created in run(), removed in after()
#   zts                   - result of inversion.ddim_inversion, kept until reset()
#   orig_prompt_attention - shared.opts.prompt_attention saved in run(), restored in after()
handler = zts = orig_prompt_attention = None
class Script(scripts_manager.Script):
    """Script that wires Style Aligned generation into the processing pipeline.

    State is kept in the module-level globals ``handler``, ``zts`` and
    ``orig_prompt_attention``: ``run`` sets them up before processing and
    ``after`` tears them down again.
    """

    def title(self):
        """Return the script name."""
        return 'Style Aligned Image Generation'

    def show(self, is_img2img):
        """Return True regardless of ``is_img2img``."""
        return True

    def reset(self):
        """Clear the cached handler and inversion result.

        Wired to the image component's change event in ``ui`` so that the
        next ``run`` performs a fresh inversion for the new image.
        """
        global handler, zts # pylint: disable=global-statement
        handler = None
        zts = None
        shared.log.info('SA: image upload')

    def preset(self, preset):
        """Map a preset name to UI values.

        Returns a list ``[shared_opts, scale, shift, level]`` in the same
        order as the outputs bound to the preset dropdown in ``ui``.
        Any name other than 'text' or 'image' yields the 'all' values.
        """
        if preset == 'text':
            return [['attention', 'adain_queries', 'adain_keys'], 1.0, 0, 0.0]
        if preset == 'image':
            return [['group_norm', 'layer_norm', 'attention', 'adain_queries', 'adain_keys'], 1.0, 2, 0.0]
        return [['group_norm', 'layer_norm', 'attention', 'adain_queries', 'adain_keys', 'adain_values', 'full_attention_share'], 1.0, 1, 0.5]

    @staticmethod
    def _ddim_scheduler():
        """Build the DDIM scheduler used both for the override and for inversion."""
        return diffusers.DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)

    def ui(self, _is_img2img): # ui elements
        """Create the script's UI and return its components in ``run`` argument order."""
        with gr.Row():
            # NOTE(review): this literal was split across two physical lines (a syntax
            # error); rejoined with an explicit newline. Any markup it once carried
            # is not visible here - confirm the intended HTML.
            gr.HTML('  Style Aligned Image Generation\n')
        with gr.Row():
            preset = gr.Dropdown(label="Preset", choices=['text', 'image', 'all'], value='text')
            scheduler = gr.Checkbox(label="Override scheduler", value=False)
        with gr.Row():
            shared_opts = gr.Dropdown(label="Shared options",
                                      multiselect=True,
                                      choices=['group_norm', 'layer_norm', 'attention', 'adain_queries', 'adain_keys', 'adain_values', 'full_attention_share'],
                                      value=['attention', 'adain_queries', 'adain_keys'],
                                      )
        with gr.Row():
            shared_score_scale = gr.Slider(label="Scale", minimum=0.0, maximum=2.0, step=0.01, value=1.0)
            shared_score_shift = gr.Slider(label="Shift", minimum=0, maximum=10, step=1, value=0)
            only_self_level = gr.Slider(label="Level", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
        with gr.Row():
            prompt = gr.Textbox(lines=1, label='Optional image description', placeholder='use the style from the image')
        with gr.Row():
            image = gr.Image(label='Optional image', type='pil')
        # a new (or cleared) image invalidates the cached inversion
        image.change(self.reset)
        preset.change(self.preset, inputs=[preset], outputs=[shared_opts, shared_score_scale, shared_score_shift, only_self_level])
        return [image, prompt, scheduler, shared_opts, shared_score_scale, shared_score_shift, only_self_level]

    def run(self, p: processing.StableDiffusionProcessing, image, prompt, scheduler, shared_opts, shared_score_scale, shared_score_shift, only_self_level): # pylint: disable=arguments-differ
        """Register the style-aligned handler and prepare ``p`` for processing.

        Args:
            p: processing object; its ``prompt``, ``batch_size``, ``sampler_name``
                and ``task_args`` may be modified.
            image: optional PIL image; when given and no inversion is cached,
                it is inverted with ``inversion.ddim_inversion``.
            prompt: text passed to the inversion together with ``image``.
            scheduler: when truthy, replace the model scheduler with DDIM.
            shared_opts: names of the sharing options to enable.
            shared_score_scale: passed on as ``float``.
            shared_score_shift: passed on as ``np.log(value)`` when positive, else 0.
            only_self_level: passed on as 1 when truthy, else 0.

        Returns:
            Always None; for unsupported model types nothing is modified.
        """
        global handler, zts, orig_prompt_attention # pylint: disable=global-statement
        if shared.sd_model_type not in supported_model_list:
            shared.log.warning(f'SA: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}')
            return None

        from scripts.style_aligned import sa_handler, inversion # pylint: disable=no-name-in-module

        if handler is not None:
            # a previous run never reached after(); detach it before registering a new one
            handler.remove()
        handler = sa_handler.Handler(shared.sd_model)
        shared_opts = shared_opts or [] # tolerate a cleared multiselect
        sa_args = sa_handler.StyleAlignedArgs(
            share_group_norm='group_norm' in shared_opts,
            share_layer_norm='layer_norm' in shared_opts,
            share_attention='attention' in shared_opts,
            adain_queries='adain_queries' in shared_opts,
            adain_keys='adain_keys' in shared_opts,
            adain_values='adain_values' in shared_opts,
            full_attention_share='full_attention_share' in shared_opts,
            shared_score_scale=float(shared_score_scale),
            shared_score_shift=np.log(shared_score_shift) if shared_score_shift > 0 else 0,
            # NOTE(review): this collapses the 0..1 slider (and the 0.5 from the 'all'
            # preset) to 0/1 - confirm whether StyleAlignedArgs expects a fraction here
            only_self_level=1 if only_self_level else 0,
        )
        handler.register(sa_args)

        if scheduler:
            shared.sd_model.scheduler = self._ddim_scheduler()
            p.sampler_name = 'None'

        if image is not None and zts is None:
            shared.log.info(f'SA: inversion image={image} prompt="{prompt}"')
            # force 3 channels: uploads may be RGBA / greyscale / palette images
            image = image.convert('RGB').resize((1024, 1024))
            x0 = np.array(image).astype(np.float32) / 255.0
            shared.sd_model.scheduler = self._ddim_scheduler()
            zts = inversion.ddim_inversion(shared.sd_model, x0, prompt, num_inference_steps=50, guidance_scale=2)

        # one prompt per line, one batch item per prompt
        prompts = p.prompt.splitlines() if isinstance(p.prompt, str) else list(p.prompt)
        if not prompts:
            prompts = [''] # an empty prompt would otherwise produce batch_size=0
        p.prompt = prompts
        p.batch_size = len(prompts)

        if orig_prompt_attention is None:
            # keep the value saved by an earlier run that never reached after(),
            # otherwise 'fixed' would be recorded as the user's setting
            orig_prompt_attention = shared.opts.prompt_attention
        shared.opts.data['prompt_attention'] = 'fixed' # otherwise need to deal with class_tokens_mask

        if zts is not None:
            processing.fix_seed(p)
            zT, inversion_callback = inversion.make_inversion_callback(zts, offset=0)
            generator = torch.Generator(device='cpu')
            generator.manual_seed(p.seed)
            latents = torch.randn(p.batch_size, 4, 128, 128, device='cpu', generator=generator, dtype=devices.dtype,).to(devices.device)
            latents[0] = zT # first batch item starts from the inversion result
            p.task_args['latents'] = latents
            p.task_args['callback_on_step_end'] = inversion_callback

        shared.log.info(f'SA: batch={p.batch_size} type={"image" if zts is not None else "text"} config={sa_args.__dict__}')
        return None

    def after(self, p: processing.StableDiffusionProcessing, *args): # pylint: disable=unused-argument
        """Undo what ``run`` set up: remove the handler and restore ``prompt_attention``."""
        global handler, orig_prompt_attention # pylint: disable=global-statement
        if handler is not None:
            handler.remove()
            handler = None
        if orig_prompt_attention is not None:
            # only restore when run() actually saved a value; an early return in run()
            # leaves this as None and writing that back would clobber the user's option
            shared.opts.data['prompt_attention'] = orig_prompt_attention
            orig_prompt_attention = None