1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/scripts/tiling.py
Vladimir Mandic e8b5ea3847 major refactor: remove backend original
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-07-05 13:16:46 -04:00

105 lines
4.7 KiB
Python

from typing import Optional
import torch
import gradio as gr
from PIL import Image
from diffusers.models.lora import LoRACompatibleConv
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.utils import _pair
from modules import scripts_manager, processing, shared
# Padding modes (torch.nn.functional.pad `mode` values) used by asymmetricConv2DConvForward
# for the x (width) and y (height) axes; Script.run overwrites them before generation.
modex = 'constant'
modey = 'constant'


def asymmetricConv2DConvForward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]): # pylint: disable=redefined-builtin
    """Replacement for ``torch.nn.Conv2d._conv_forward`` with independent x/y padding modes.

    It is bound to a conv layer via ``asymmetricConv2DConvForward.__get__(layer, torch.nn.Conv2d)``.
    The layer's own padding amounts are applied explicitly, the x-axis with the
    module-level ``modex`` and the y-axis with ``modey``. The convolution then runs
    with zero implicit padding.

    Args:
        self: the conv layer whose padding, stride, dilation and groups are used.
        input: tensor to convolve; its last dim is padded as x, the one before it as y.
        weight: convolution kernel (normally ``self.weight``).
        bias: optional bias (normally ``self.bias``).

    Returns:
        The output of ``F.conv2d`` on the explicitly padded input.
    """
    pad = self._reversed_padding_repeated_twice # pylint: disable=protected-access
    # F.pad consumes pad pairs starting from the last dim: first pair -> width (x), second pair -> height (y)
    self.paddingX = (pad[0], pad[1], 0, 0)
    self.paddingY = (0, 0, pad[2], pad[3])
    # a falsy mode (e.g. an unset dropdown) falls back to plain zero padding
    working = F.pad(input, self.paddingX, mode=modex or 'constant')
    working = F.pad(working, self.paddingY, mode=modey or 'constant')  # y-axis must use modey, not modex
    # padding was applied explicitly above, so the convolution itself uses none
    return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
class Script(scripts_manager.Script):
    """Asymmetric Tiling: independent x/y padding modes for the conv layers of SD/SDXL pipelines.

    ``run`` replaces ``_conv_forward`` on every ``torch.nn.Conv2d`` found in the pipeline's
    ``vae``, ``text_encoder`` and ``unet`` with ``asymmetricConv2DConvForward``.
    ``after`` restores the original forwards and optionally repeats each output image
    along x and/or y.
    """

    def __init__(self):
        super().__init__()
        self.orig_pipe = None  # pipeline captured by run(); cleared again by after()
        self.conv_layers = []  # Conv2d layers patched by run() that after() must restore
        self.modes = ['constant', 'circular', 'reflect', 'replicate']  # padding modes offered per axis

    def title(self):
        return 'Asymmetric Tiling'

    def show(self, is_img2img):
        return True  # shown for both txt2img and img2img

    def ui(self, _is_img2img): # ui elements
        """Build the per-axis mode dropdowns and repeat sliders, returned in run()/after() argument order."""
        with gr.Row():
            gr.HTML('<b>Asymmetric Tiling</b><br>')
        with gr.Row():
            tilex = gr.Dropdown(label="Mode x-axis", choices=self.modes, value='constant')
            numx = gr.Slider(label="Repeat x-axis", value=1, minimum=1, maximum=10, step=1)
        with gr.Row():
            tiley = gr.Dropdown(label="Mode y-axis", choices=self.modes, value='constant')
            numy = gr.Slider(label="Repeat y-axis", value=1, minimum=1, maximum=10, step=1)
        return [tilex, numx, tiley, numy]

    @staticmethod
    def _repeat_image(image: Image.Image, count_x: int, count_y: int) -> Image.Image:
        """Return ``image`` repeated ``count_x`` times horizontally and ``count_y`` times vertically."""
        if count_x <= 1 and count_y <= 1:
            return image
        # keep alpha for RGBA output; every other mode is converted to RGB by paste()
        mode = 'RGBA' if image.mode == 'RGBA' else 'RGB'
        tiled = Image.new(mode, (image.width * count_x, image.height * count_y))
        for row in range(count_y):
            for col in range(count_x):
                tiled.paste(image, (col * image.width, row * image.height))
        return tiled

    def _restore_layers(self):
        """Undo every patch run() applied to the layers in self.conv_layers, then forget them."""
        for cl in self.conv_layers:
            if hasattr(cl, '_orig_conv_forward'):
                cl._conv_forward = cl._orig_conv_forward # pylint: disable=protected-access
                del cl._orig_conv_forward  # so the next run() saves a fresh original
            if getattr(cl, '_tiling_lora_stub', False):
                cl.lora_layer = None  # remove the stub installed by run()
                del cl._tiling_lora_stub
        self.conv_layers.clear()

    def run(self, p: processing.StableDiffusionProcessing, tilex=False, numx:int=1, tiley=False, numy:int=1): # pylint: disable=arguments-differ, unused-argument
        """Patch the conv layers of the loaded model for asymmetric padding.

        Args:
            p: processing job (unused; the model is taken from ``shared.sd_model``).
            tilex: x-axis padding mode, one of ``self.modes``; a falsy value means zero padding.
            numx: how many times after() repeats each image horizontally.
            tiley: y-axis padding mode, one of ``self.modes``; a falsy value means zero padding.
            numy: how many times after() repeats each image vertically.

        Returns:
            Always None. NOTE(review): presumably this lets the caller continue with normal
            processing; confirm against scripts_manager.
        """
        global modex, modey # pylint: disable=global-statement
        supported_model_list = ['sd', 'sdxl']
        if shared.sd_model_type not in supported_model_list:
            shared.log.warning(f'Tiling: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}')
            return None
        if not tilex and not tiley:
            return None
        self.orig_pipe = shared.sd_model
        # an unset dropdown must not reach F.pad as an invalid mode
        modex = tilex or 'constant'
        modey = tiley or 'constant'
        self.conv_layers.clear()
        for name in ('vae', 'text_encoder', 'unet'):
            target = getattr(shared.sd_model, name, None)
            if target is None:  # component missing or not loaded
                continue
            self.conv_layers.extend(module for module in target.modules() if isinstance(module, torch.nn.Conv2d))
        for cl in self.conv_layers:
            if isinstance(cl, LoRACompatibleConv) and cl.lora_layer is None:
                # NOTE(review): a non-None lora_layer presumably routes LoRACompatibleConv.forward through
                # _conv_forward; the stub always returns 0 and is removed again by after()
                cl.lora_layer = lambda *x: 0
                cl._tiling_lora_stub = True # pylint: disable=protected-access
            if hasattr(cl, '_conv_forward'):
                if not hasattr(cl, '_orig_conv_forward'):
                    # a repeated run() without after() must not save our own patch as the "original"
                    cl._orig_conv_forward = cl._conv_forward # pylint: disable=protected-access
                cl._conv_forward = asymmetricConv2DConvForward.__get__(cl, torch.nn.Conv2d) # pylint: disable=protected-access, no-value-for-parameter
        shared.log.info(f'Tiling: x={tilex}:{numx} y={tiley}:{numy}')
        return None

    def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, tilex=False, numx:int=1, tiley=False, numy:int=1): # pylint: disable=arguments-differ, unused-argument
        """Restore the patched layers and repeat output images per the x/y repeat counts.

        Returns:
            ``processed``, with ``processed.images`` replaced by the repeated images
            when the layers had been patched by run().
        """
        if not self.conv_layers:
            return processed
        self._restore_layers()
        if self.orig_pipe is None:
            return processed
        if shared.sd_model_type == "sdxl":
            shared.sd_model = self.orig_pipe
        self.orig_pipe = None
        if not hasattr(processed, 'images') or processed.images is None:
            return processed
        # slider values may arrive as floats; an axis repeats only when its mode is set
        count_x = max(1, int(numx or 1)) if tilex else 1
        count_y = max(1, int(numy or 1)) if tiley else 1
        processed.images = [
            self._repeat_image(image, count_x, count_y) if isinstance(image, Image.Image) else image
            for image in processed.images
        ]
        return processed