mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
94 lines
4.1 KiB
Python
94 lines
4.1 KiB
Python
import gradio as gr
|
|
import torch
|
|
from modules import shared, devices, scripts, processing, sd_models
|
|
|
|
|
|
checked_ok = False # set to True by check_dependencies() once ligo.segments has been imported successfully
|
|
|
|
|
|
def check_dependencies():
    """Install missing requirements and confirm that `ligo.segments` can be imported.

    Sets the module-level `checked_ok` flag to True on success.

    Returns:
        True when the import succeeds, False when it fails (the error is logged).
    """
    global checked_ok # pylint: disable=global-statement
    from installer import installed, install
    required = (
        ('ligo-segments', 'ligo-segments'),
    )
    for install_arg, check_name in required:
        if installed(check_name, reload=True, quiet=True):
            continue
        install(install_arg, check_name, ignore=False)
    try:
        from ligo.segments import segment # pylint: disable=unused-import
    except Exception as e:
        shared.log.error(f'Mixture tiling: {e}')
        return False
    checked_ok = True
    return True
|
|
|
|
|
|
class Script(scripts.Script):
    """Script that composes one image from a grid of prompts using the mixture-tiling pipeline.

    The prompt is split on new lines; the lines are arranged into `y_size` rows of
    `x_size` prompts each and passed to the pipeline as a nested list.
    """

    def title(self):
        """Return the display name of the script."""
        return 'Mixture Tiling: Scene Composition'

    def show(self, is_img2img):
        """Return True only when `shared.native` is set and the tab is not img2img."""
        return not is_img2img if shared.native else False

    def ui(self, _is_img2img):
        """Build the script controls.

        Returns:
            The four sliders (X components, Y components, X overlap, Y overlap), in the
            same order as the extra parameters of `run`.
        """
        with gr.Row():
            gr.HTML('<a href="https://arxiv.org/abs/2302.02412">  Mixture Tiling: Scene Composition</a><br>')
        with gr.Row():
            gr.HTML('<span>  Separated prompts using new lines<br>  Number of prompts must match X*Y</span>')
        with gr.Row():
            x_size = gr.Slider(label='X components', minimum=1, maximum=5, step=1, value=1)
            y_size = gr.Slider(label='Y components', minimum=1, maximum=5, step=1, value=1)
        with gr.Row():
            x_overlap = gr.Slider(label='X overlap', minimum=0, maximum=1, step=0.01, value=0.5)
            y_overlap = gr.Slider(label='Y overlap', minimum=0, maximum=1, step=0.01, value=0.5)
        return x_size, y_size, x_overlap, y_overlap

    def run(self, p: processing.StableDiffusionProcessing, x_size, y_size, x_overlap, y_overlap): # pylint: disable=arguments-differ
        """Generate a tiled composition from the new-line separated prompts in `p.prompt`.

        Args:
            p: processing object; its prompt, seed, width and height are read and
                `p.task_args` is filled with the tiling pipeline arguments.
            x_size: number of prompts per row (columns of the prompt grid).
            y_size: number of rows of the prompt grid.
            x_overlap: column overlap as a fraction of `p.width`.
            y_overlap: row overlap as a fraction of `p.height`.

        Returns:
            The result of `processing.process_images(p)`, or None when the dependency
            check fails, the prompt count differs from `x_size * y_size`, the loaded
            model type is not 'sd', or switching to the tiling pipeline fails.
        """
        if not checked_ok and not check_dependencies():
            return None
        # slider values may arrive as floats (depends on gradio, not visible here); range() and slicing need ints
        x_size, y_size = int(x_size), int(y_size)
        prompts = p.prompt.splitlines()
        required = x_size * y_size
        if len(prompts) != required:
            shared.log.error(f'Mixture tiling prompt count mismatch: prompts={len(prompts)} required={required}')
            return None
        # backup pipeline and params
        orig_pipeline = shared.sd_model
        orig_dtype = devices.dtype
        orig_prompt_attention = shared.opts.prompt_attention
        # create pipeline
        if shared.sd_model_type != 'sd':
            shared.log.error(f'Mixture tiling: incorrect base model: {shared.sd_model.__class__.__name__}')
            return None
        shared.sd_model = sd_models.switch_pipe('mixture_tiling', shared.sd_model)
        if shared.sd_model.__class__.__name__ != 'StableDiffusionTilingPipeline': # switch failed
            shared.log.error(f'Mixture tiling: not a tiling pipeline: {shared.sd_model.__class__.__name__}')
            shared.sd_model = orig_pipeline
            return None
        try:
            sd_models.set_diffuser_options(shared.sd_model)
            shared.opts.data['prompt_attention'] = 'fixed' # this pipeline is not compatible with embeds
            shared.sd_model.to(torch.float32) # this pipeline unet is not compatible with fp16
            processing.fix_seed(p)
            # set pipeline specific params, note that standard params are applied when applicable
            # prompt grid: one inner list per row, each holding x_size consecutive prompts
            p.task_args['prompt'] = [prompts[row * x_size:(row + 1) * x_size] for row in range(y_size)]
            p.task_args['seed'] = p.seed
            p.task_args['tile_width'] = p.width
            p.task_args['tile_height'] = p.height
            # columns overlap along the width, rows overlap along the height
            p.task_args['tile_col_overlap'] = int(p.width * x_overlap)
            p.task_args['tile_row_overlap'] = int(p.height * y_overlap)
            p.task_args['output_type'] = 'np'
            # run pipeline
            shared.log.debug(f'Tiling: args={p.task_args}')
            processed: processing.Processed = processing.process_images(p) # runs processing using main loop
        finally:
            # restore pipeline and params even if processing raised, so global state is not left modified
            shared.opts.data['prompt_attention'] = orig_prompt_attention
            shared.sd_model = orig_pipeline
            shared.sd_model.to(orig_dtype)
        return processed
|