mirror of https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00

CHANGELOG.md (16 changed lines)
@@ -1,6 +1,6 @@
# Change Log for SD.Next

-## Update for 2024-12-13
+## Update for 2024-12-15

### New models and integrations
@@ -32,13 +32,19 @@
  enter multiple prompts in prompt field separated by new line
  style-aligned applies selected attention layers uniformly to all images to achieve consistency
  can be used with or without input image in which case first prompt is used to establish baseline
  *note:* all prompts are processed as a single batch, so vram is limiting factor
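  for example, a prompt field with newline-separated entries (illustrative prompts, not from the changelog) renders all of them as one style-aligned batch:

  ```text
  cinematic photo of a lighthouse at dusk
  watercolor painting of a lighthouse at dusk
  isometric pixel-art lighthouse at dusk
  ```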
- [FreeScale](https://github.com/ali-vilab/FreeScale)
  enable in scripts, compatible with sd-xl for text and img2img
  run iterative generation of images at different scales to achieve better results
  can render 4k sdxl images
  *note*: disable live preview to avoid memory issues when generating large images
- **ControlNet**
  - improved support for **Union** controlnets with granular control mode type
  - added support for latest [Xinsir ProMax](https://huggingface.co/xinsir/controlnet-union-sdxl-1.0) all-in-one controlnet
  - added support for multiple **Tiling** controlnets, for example [Xinsir Tile](https://huggingface.co/xinsir/controlnet-tile-sdxl-1.0)
    *note*: when selecting tiles in control settings, you can also specify non-square ratios
    in which case it will use context-aware image resize to maintain overall composition
  *note*: available tiling options can be set in settings -> control

### UI and workflow improvements
@@ -118,6 +124,8 @@
- fix cogvideox-i2v
- lora auto-apply tags remove duplicates
- control load model on-demand if not already loaded
- taesd limit render to 2024px
- taesd downscale preview to 1024px max

## Update for 2024-11-21
@@ -1,3 +1,4 @@
+import time
from PIL import Image
from modules import shared, processing, images, sd_models
@@ -17,30 +18,25 @@ def set_tile(image: Image.Image, x: int, y: int, tiled: Image.Image):


def run_tiling(p: processing.StableDiffusionProcessing, input_image: Image.Image) -> processing.Processed:
    t0 = time.time()
    # prepare images
    sx, sy = p.control_tile.split('x')
    sx = int(sx)
    sy = int(sy)
    if sx <= 0 or sy <= 0:
-        raise ValueError('Control: invalid tile size')
+        raise ValueError('Control Tile: invalid tile size')
    control_image = p.task_args.get('control_image', None) or p.task_args.get('image', None)
    control_upscaled = None
    if isinstance(control_image, list) and len(control_image) > 0:
-        control_upscaled = images.resize_image(resize_mode=1 if sx==sy else 5,
-                                               im=control_image[0],
-                                               width=8 * int(sx * control_image[0].width) // 8,
-                                               height=8 * int(sy * control_image[0].height) // 8,
-                                               context='add with forward'
-                                               )
+        w, h = 8 * int(sx * control_image[0].width) // 8, 8 * int(sy * control_image[0].height) // 8
+        control_upscaled = images.resize_image(resize_mode=1 if sx==sy else 5, im=control_image[0], width=w, height=h, context='add with forward')
    init_image = p.override or input_image
    init_upscaled = None
    if init_image is not None:
-        init_upscaled = images.resize_image(resize_mode=1 if sx==sy else 5,
-                                            im=init_image,
-                                            width=8 * int(sx * init_image.width) // 8,
-                                            height=8 * int(sy * init_image.height) // 8,
-                                            context='add with forward'
-                                            )
+        w, h = 8 * int(sx * init_image.width) // 8, 8 * int(sy * init_image.height) // 8
+        init_upscaled = images.resize_image(resize_mode=1 if sx==sy else 5, im=init_image, width=w, height=h, context='add with forward')
    t1 = time.time()
    shared.log.debug(f'Control Tile: scale={sx}x{sy} resize={"fixed" if sx==sy else "context"} control={control_upscaled} init={init_upscaled} time={t1-t0:.3f}')

    # stop processing from restoring pipeline on each iteration
    orig_restore_pipeline = getattr(shared.sd_model, 'restore_pipeline', None)

@@ -72,4 +68,6 @@ def run_tiling(p: processing.StableDiffusionProcessing, input_image: Image.Image
    shared.sd_model.restore_pipeline = orig_restore_pipeline
+    if hasattr(shared.sd_model, 'restore_pipeline') and shared.sd_model.restore_pipeline is not None:
+        shared.sd_model.restore_pipeline()
    t2 = time.time()
    shared.log.debug(f'Control Tile: image={control_upscaled} time={t2-t0:.3f}')
    return processed
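# --- editorial sketch, not part of the diff -----------------------------------
# The elided middle of run_tiling walks the sx x sy grid, renders each tile and
# pastes the result back via set_tile (signature shown in the hunk header above).
# A minimal illustration of that pattern, with a hypothetical get_tile helper and
# a hypothetical per-tile process() call standing in for the actual pipeline run:
#
#   def get_tile(image: Image.Image, x: int, y: int, sx: int, sy: int) -> Image.Image:
#       w, h = image.width // sx, image.height // sy
#       return image.crop((x * w, y * h, (x + 1) * w, (y + 1) * h))
#
#   for x in range(sx):
#       for y in range(sy):
#           tile = get_tile(control_upscaled, x, y, sx, sy)
#           set_tile(control_upscaled, x, y, process(tile))
# -------------------------------------------------------------------------------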
@@ -101,6 +101,14 @@ predefined_sd3 = {
    "Alimama Inpainting": 'alimama-creative/SD3-Controlnet-Inpainting',
    "Alimama SoftEdge": 'alimama-creative/SD3-Controlnet-Softedge',
}
variants = {
    'NoobAI Canny XL': 'fp16',
    'NoobAI Lineart Anime XL': 'fp16',
    'NoobAI Depth XL': 'fp16',
    'NoobAI Normal XL': 'fp16',
    'NoobAI SoftEdge XL': 'fp16',
    'TTPlanet Tile Realistic XL': 'fp16',
}
models = {}
all_models = {}
all_models.update(predefined_sd15)

@@ -261,8 +269,8 @@ class ControlNet():
        if cls is None:
            log.error(f'Control {what} model load failed: id="{model_id}" unknown base model')
            return
-        if 'Eugeoter' in model_path:
-            kwargs['variant'] = 'fp16'
+        if variants.get(model_id, None) is not None:
+            kwargs['variant'] = variants[model_id]
        self.model = cls.from_pretrained(model_path, **self.load_config, **kwargs)
        if self.model is None:
            return
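# Editorial note, not part of the diff: the new variants table replaces the old
# hard-coded 'Eugeoter' check, so any model whose display name is listed above is
# loaded as its fp16 weight variant. Assuming the resolved class is diffusers'
# ControlNetModel, the call above reduces to roughly:
#
#   self.model = ControlNetModel.from_pretrained(model_path, variant='fp16', **self.load_config)
#
# Models absent from the table keep the default full-precision weights.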
modules/freescale/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
# Credits: https://github.com/ali-vilab/FreeScale

from .freescale_pipeline import StableDiffusionXLFreeScale
from .freescale_pipeline_img2img import StableDiffusionXLFreeScaleImg2Img
modules/freescale/free_lunch_utils.py (new file, 305 lines)
@@ -0,0 +1,305 @@
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
import torch.fft as fft
|
||||
from diffusers.utils import is_torch_version
|
||||
|
||||
""" Borrowed from https://github.com/ChenyangSi/FreeU/blob/main/demo/free_lunch_utils.py
|
||||
"""
|
||||
|
||||
def isinstance_str(x: object, cls_name: str):
|
||||
"""
|
||||
Checks whether x has any class *named* cls_name in its ancestry.
|
||||
Doesn't require access to the class's implementation.
|
||||
|
||||
Useful for patching!
|
||||
"""
|
||||
|
||||
for _cls in x.__class__.__mro__:
|
||||
if _cls.__name__ == cls_name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def Fourier_filter(x, threshold, scale):
|
||||
dtype = x.dtype
|
||||
x = x.type(torch.float32)
|
||||
# FFT
|
||||
x_freq = fft.fftn(x, dim=(-2, -1))
|
||||
x_freq = fft.fftshift(x_freq, dim=(-2, -1))
|
||||
|
||||
B, C, H, W = x_freq.shape
|
||||
mask = torch.ones((B, C, H, W)).cuda()
|
||||
|
||||
crow, ccol = H // 2, W //2
|
||||
mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
|
||||
x_freq = x_freq * mask
|
||||
|
||||
# IFFT
|
||||
x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
|
||||
x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
|
||||
|
||||
x_filtered = x_filtered.type(dtype)
|
||||
return x_filtered
|
||||
|
||||
|
||||
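# Editorial sketch, not part of the original file: Fourier_filter rescales the
# low-frequency band of a feature map around the centred FFT spectrum. The mask is
# created with .cuda(), so the input is assumed to be a CUDA tensor; dampening the
# lowest frequencies by half looks like:
#
#   feat = torch.randn(1, 1280, 32, 32, device='cuda')
#   filtered = Fourier_filter(feat, threshold=1, scale=0.5)
#
# threshold counts frequency bins around the centre (threshold=1 touches a 2x2
# block) and scale<1 dampens that band while scale>1 boosts it; FreeU applies
# scales s1/s2 below 1 to the skip-connection features in the blocks below.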
def register_upblock2d(model):
|
||||
def up_forward(self):
|
||||
def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
#print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
for i, upsample_block in enumerate(model.unet.up_blocks):
|
||||
if isinstance_str(upsample_block, "UpBlock2D"):
|
||||
upsample_block.forward = up_forward(upsample_block)
|
||||
|
||||
|
||||
def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
|
||||
def up_forward(self):
|
||||
def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
#print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
|
||||
|
||||
# --------------- FreeU code -----------------------
|
||||
# Only operate on the first two stages
|
||||
if hidden_states.shape[1] == 1280:
|
||||
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
||||
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
||||
if hidden_states.shape[1] == 640:
|
||||
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
||||
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
||||
# ---------------------------------------------------------
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
for i, upsample_block in enumerate(model.unet.up_blocks):
|
||||
if isinstance_str(upsample_block, "UpBlock2D"):
|
||||
upsample_block.forward = up_forward(upsample_block)
|
||||
setattr(upsample_block, 'b1', b1)
|
||||
setattr(upsample_block, 'b2', b2)
|
||||
setattr(upsample_block, 's1', s1)
|
||||
setattr(upsample_block, 's2', s2)
|
||||
|
||||
|
||||
def register_crossattn_upblock2d(model):
|
||||
def up_forward(self):
|
||||
def forward(
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
#print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
None, # timestep
|
||||
None, # class_labels
|
||||
cross_attention_kwargs,
|
||||
attention_mask,
|
||||
encoder_attention_mask,
|
||||
**ckpt_kwargs,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
for i, upsample_block in enumerate(model.unet.up_blocks):
|
||||
if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
|
||||
upsample_block.forward = up_forward(upsample_block)
|
||||
|
||||
|
||||
def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
|
||||
def up_forward(self):
|
||||
def forward(
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
#print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# --------------- FreeU code -----------------------
|
||||
# Only operate on the first two stages
|
||||
if hidden_states.shape[1] == 1280:
|
||||
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
||||
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
||||
if hidden_states.shape[1] == 640:
|
||||
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
||||
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
||||
# ---------------------------------------------------------
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
None, # timestep
|
||||
None, # class_labels
|
||||
cross_attention_kwargs,
|
||||
attention_mask,
|
||||
encoder_attention_mask,
|
||||
**ckpt_kwargs,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
# hidden_states = attn(
|
||||
# hidden_states,
|
||||
# encoder_hidden_states=encoder_hidden_states,
|
||||
# cross_attention_kwargs=cross_attention_kwargs,
|
||||
# encoder_attention_mask=encoder_attention_mask,
|
||||
# return_dict=False,
|
||||
# )[0]
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)[0]
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size)
|
||||
|
||||
return hidden_states
|
||||
|
||||
return forward
|
||||
|
||||
for i, upsample_block in enumerate(model.unet.up_blocks):
|
||||
if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
|
||||
upsample_block.forward = up_forward(upsample_block)
|
||||
setattr(upsample_block, 'b1', b1)
|
||||
setattr(upsample_block, 'b2', b2)
|
||||
setattr(upsample_block, 's1', s1)
|
||||
setattr(upsample_block, 's2', s2)
|
||||
modules/freescale/freescale_pipeline.py (new file, 1189 lines; diff suppressed because it is too large)
modules/freescale/freescale_pipeline_img2img.py (new file, 1245 lines; diff suppressed because it is too large)

modules/freescale/scale_attention.py (new file, 367 lines)
@@ -0,0 +1,367 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import random
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def gaussian_kernel(kernel_size=3, sigma=1.0, channels=3):
|
||||
x_coord = torch.arange(kernel_size)
|
||||
gaussian_1d = torch.exp(-(x_coord - (kernel_size - 1) / 2) ** 2 / (2 * sigma ** 2))
|
||||
gaussian_1d = gaussian_1d / gaussian_1d.sum()
|
||||
gaussian_2d = gaussian_1d[:, None] * gaussian_1d[None, :]
|
||||
kernel = gaussian_2d[None, None, :, :].repeat(channels, 1, 1, 1)
|
||||
|
||||
return kernel
|
||||
|
||||
def gaussian_filter(latents, kernel_size=3, sigma=1.0):
|
||||
channels = latents.shape[1]
|
||||
kernel = gaussian_kernel(kernel_size, sigma, channels).to(latents.device, latents.dtype)
|
||||
blurred_latents = F.conv2d(latents, kernel, padding=kernel_size//2, groups=channels)
|
||||
|
||||
return blurred_latents
|
||||
|
||||
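# Editorial sketch, not part of the original file: gaussian_filter blurs each
# latent channel independently via a depthwise (grouped) convolution with a
# normalised kernel, so overall magnitude is preserved. For an SDXL-style latent:
#
#   latents = torch.randn(1, 4, 128, 128)
#   smoothed = gaussian_filter(latents, kernel_size=3, sigma=1.0)
#
# FreeScale uses it below to split attention outputs into a low-frequency
# (blurred) part and a high-frequency residual before fusing local and global
# branches.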
def get_views(height, width, h_window_size=128, w_window_size=128, scale_factor=8):
|
||||
height = int(height)
|
||||
width = int(width)
|
||||
h_window_stride = h_window_size // 2
|
||||
w_window_stride = w_window_size // 2
|
||||
h_window_size = int(h_window_size / scale_factor)
|
||||
w_window_size = int(w_window_size / scale_factor)
|
||||
h_window_stride = int(h_window_stride / scale_factor)
|
||||
w_window_stride = int(w_window_stride / scale_factor)
|
||||
num_blocks_height = int((height - h_window_size) / h_window_stride - 1e-6) + 2 if height > h_window_size else 1
|
||||
num_blocks_width = int((width - w_window_size) / w_window_stride - 1e-6) + 2 if width > w_window_size else 1
|
||||
total_num_blocks = int(num_blocks_height * num_blocks_width)
|
||||
views = []
|
||||
for i in range(total_num_blocks):
|
||||
h_start = int((i // num_blocks_width) * h_window_stride)
|
||||
h_end = h_start + h_window_size
|
||||
w_start = int((i % num_blocks_width) * w_window_stride)
|
||||
w_end = w_start + w_window_size
|
||||
|
||||
if h_end > height:
|
||||
h_start = int(h_start + height - h_end)
|
||||
h_end = int(height)
|
||||
if w_end > width:
|
||||
w_start = int(w_start + width - w_end)
|
||||
w_end = int(width)
|
||||
if h_start < 0:
|
||||
h_end = int(h_end - h_start)
|
||||
h_start = 0
|
||||
if w_start < 0:
|
||||
w_end = int(w_end - w_start)
|
||||
w_start = 0
|
||||
|
||||
random_jitter = True
|
||||
if random_jitter:
|
||||
h_jitter_range = h_window_size // 8
|
||||
w_jitter_range = w_window_size // 8
|
||||
h_jitter = 0
|
||||
w_jitter = 0
|
||||
|
||||
if (w_start != 0) and (w_end != width):
|
||||
w_jitter = random.randint(-w_jitter_range, w_jitter_range)
|
||||
elif (w_start == 0) and (w_end != width):
|
||||
w_jitter = random.randint(-w_jitter_range, 0)
|
||||
elif (w_start != 0) and (w_end == width):
|
||||
w_jitter = random.randint(0, w_jitter_range)
|
||||
if (h_start != 0) and (h_end != height):
|
||||
h_jitter = random.randint(-h_jitter_range, h_jitter_range)
|
||||
elif (h_start == 0) and (h_end != height):
|
||||
h_jitter = random.randint(-h_jitter_range, 0)
|
||||
elif (h_start != 0) and (h_end == height):
|
||||
h_jitter = random.randint(0, h_jitter_range)
|
||||
h_start += (h_jitter + h_jitter_range)
|
||||
h_end += (h_jitter + h_jitter_range)
|
||||
w_start += (w_jitter + w_jitter_range)
|
||||
w_end += (w_jitter + w_jitter_range)
|
||||
|
||||
views.append((h_start, h_end, w_start, w_end))
|
||||
return views
|
||||
|
||||
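# Editorial note, not part of the original file: get_views returns overlapping
# (h_start, h_end, w_start, w_end) windows in latent coordinates; window size and
# stride are given in pixel-scale units and divided by scale_factor (8). With the
# defaults a 128x128 window and 64 stride become a 16x16 latent window stepped by
# 8, so e.g. get_views(32, 32) yields a 3x3 grid of views such as (0, 16, 0, 16),
# (0, 16, 8, 24), ... before the random jitter offsets are applied.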
def scale_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
if self.current_hw:
|
||||
current_scale_num_h, current_scale_num_w = max(self.current_hw[0] // 1024, 1), max(self.current_hw[1] // 1024, 1)
|
||||
else:
|
||||
current_scale_num_h, current_scale_num_w = 1, 1
|
||||
|
||||
# 0. Self-Attention
|
||||
if self.use_ada_layer_norm:
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
# 2. Prepare GLIGEN inputs
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
||||
|
||||
ratio_hw = current_scale_num_h / current_scale_num_w
|
||||
latent_h = int((norm_hidden_states.shape[1] * ratio_hw) ** 0.5)
|
||||
latent_w = int(latent_h / ratio_hw)
|
||||
scale_factor = 128 * current_scale_num_h / latent_h
|
||||
if ratio_hw > 1:
|
||||
sub_h = 128
|
||||
sub_w = int(128 / ratio_hw)
|
||||
else:
|
||||
sub_h = int(128 * ratio_hw)
|
||||
sub_w = 128
|
||||
|
||||
h_jitter_range = int(sub_h / scale_factor // 8)
|
||||
w_jitter_range = int(sub_w / scale_factor // 8)
|
||||
views = get_views(latent_h, latent_w, sub_h, sub_w, scale_factor = scale_factor)
|
||||
|
||||
current_scale_num = max(current_scale_num_h, current_scale_num_w)
|
||||
global_views = [[h, w] for h in range(current_scale_num_h) for w in range(current_scale_num_w)]
|
||||
|
||||
four_window = True
|
||||
fourg_window = False
|
||||
|
||||
if four_window:
|
||||
norm_hidden_states_ = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
|
||||
norm_hidden_states_ = F.pad(norm_hidden_states_, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0)
|
||||
value = torch.zeros_like(norm_hidden_states_)
|
||||
count = torch.zeros_like(norm_hidden_states_)
|
||||
for index, view in enumerate(views):
|
||||
h_start, h_end, w_start, w_end = view
|
||||
local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :]
|
||||
local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
|
||||
local_output = self.attn1(
|
||||
local_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor))
|
||||
|
||||
value[:, h_start:h_end, w_start:w_end, :] += local_output * 1
|
||||
count[:, h_start:h_end, w_start:w_end, :] += 1
|
||||
|
||||
value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
||||
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
||||
attn_output = torch.where(count>0, value/count, value)
|
||||
|
||||
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
||||
|
||||
attn_output_global = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
attn_output_global = rearrange(attn_output_global, 'bh (h w) d -> bh h w d', h = latent_h)
|
||||
|
||||
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
||||
|
||||
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
||||
attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
|
||||
|
||||
elif fourg_window:
|
||||
norm_hidden_states = rearrange(norm_hidden_states, 'bh (h w) d -> bh h w d', h = latent_h)
|
||||
norm_hidden_states_ = F.pad(norm_hidden_states, (0, 0, w_jitter_range, w_jitter_range, h_jitter_range, h_jitter_range), 'constant', 0)
|
||||
value = torch.zeros_like(norm_hidden_states_)
|
||||
count = torch.zeros_like(norm_hidden_states_)
|
||||
for index, view in enumerate(views):
|
||||
h_start, h_end, w_start, w_end = view
|
||||
local_states = norm_hidden_states_[:, h_start:h_end, w_start:w_end, :]
|
||||
local_states = rearrange(local_states, 'bh h w d -> bh (h w) d')
|
||||
local_output = self.attn1(
|
||||
local_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
local_output = rearrange(local_output, 'bh (h w) d -> bh h w d', h = int(sub_h / scale_factor))
|
||||
|
||||
value[:, h_start:h_end, w_start:w_end, :] += local_output * 1
|
||||
count[:, h_start:h_end, w_start:w_end, :] += 1
|
||||
|
||||
value = value[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
||||
count = count[:, h_jitter_range:-h_jitter_range, w_jitter_range:-w_jitter_range, :]
|
||||
attn_output = torch.where(count>0, value/count, value)
|
||||
|
||||
gaussian_local = gaussian_filter(attn_output, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
||||
|
||||
value = torch.zeros_like(norm_hidden_states)
|
||||
count = torch.zeros_like(norm_hidden_states)
|
||||
for index, global_view in enumerate(global_views):
|
||||
h, w = global_view
|
||||
global_states = norm_hidden_states[:, h::current_scale_num_h, w::current_scale_num_w, :]
|
||||
global_states = rearrange(global_states, 'bh h w d -> bh (h w) d')
|
||||
global_output = self.attn1(
|
||||
global_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
global_output = rearrange(global_output, 'bh (h w) d -> bh h w d', h = int(global_output.shape[1] ** 0.5))
|
||||
|
||||
value[:, h::current_scale_num_h, w::current_scale_num_w, :] += global_output * 1
|
||||
count[:, h::current_scale_num_h, w::current_scale_num_w, :] += 1
|
||||
|
||||
attn_output_global = torch.where(count>0, value/count, value)
|
||||
|
||||
gaussian_global = gaussian_filter(attn_output_global, kernel_size=(2*current_scale_num-1), sigma=1.0)
|
||||
|
||||
attn_output = gaussian_local + (attn_output_global - gaussian_global)
|
||||
attn_output = rearrange(attn_output, 'bh h w d -> bh (h w) d')
|
||||
|
||||
else:
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 2.5 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
||||
# 2.5 ends
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
||||
raise ValueError(
|
||||
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
||||
)
|
||||
|
||||
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
||||
ff_output = torch.cat(
|
||||
[
|
||||
self.ff(hid_slice)
|
||||
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
|
||||
],
|
||||
dim=self._chunk_dim,
|
||||
)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
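# Editorial note, not part of the original file: the key step in scale_forward is
#
#   attn_output = gaussian_local + (attn_output_global - gaussian_global)
#
# the low-frequency (Gaussian-blurred) component is taken from windowed local
# attention while the high-frequency residual comes from full global attention;
# this scale fusion keeps the global layout at high resolutions without losing
# local detail. ori_forward below is the unmodified transformer-block forward.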
def ori_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
if self.use_ada_layer_norm:
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
# 2. Prepare GLIGEN inputs
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
|
||||
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
|
||||
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 2.5 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
|
||||
# 2.5 ends
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.attn2 is not None:
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
||||
raise ValueError(
|
||||
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
||||
)
|
||||
|
||||
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
||||
ff_output = torch.cat(
|
||||
[
|
||||
self.ff(hid_slice)
|
||||
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
|
||||
],
|
||||
dim=self._chunk_dim,
|
||||
)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
@@ -365,13 +365,26 @@ def process_decode(p: processing.StableDiffusionProcessing, output):
|
||||
else:
|
||||
width = getattr(p, 'width', 0)
|
||||
height = getattr(p, 'height', 0)
|
||||
results = processing_vae.vae_decode(
|
||||
latents = output.images,
|
||||
model = model,
|
||||
full_quality = p.full_quality,
|
||||
width = width,
|
||||
height = height,
|
||||
)
|
||||
if isinstance(output.images, list):
|
||||
results = []
|
||||
for i in range(len(output.images)):
|
||||
result_batch = processing_vae.vae_decode(
|
||||
latents = output.images[i],
|
||||
model = model,
|
||||
full_quality = p.full_quality,
|
||||
width = width,
|
||||
height = height,
|
||||
)
|
||||
for result in list(result_batch):
|
||||
results.append(result)
|
||||
else:
|
||||
results = processing_vae.vae_decode(
|
||||
latents = output.images,
|
||||
model = model,
|
||||
full_quality = p.full_quality,
|
||||
width = width,
|
||||
height = height,
|
||||
)
|
||||
elif hasattr(output, 'images'):
|
||||
results = output.images
|
||||
else:
|
||||
|
||||
@@ -40,7 +40,6 @@ def single_sample_to_image(sample, approximation=None):
|
||||
if approximation is None:
|
||||
warn_once('Unknown decode type')
|
||||
approximation = 0
|
||||
# normal sample is [4,64,64]
|
||||
try:
|
||||
if sample.dtype == torch.bfloat16 and (approximation == 0 or approximation == 1):
|
||||
sample = sample.to(torch.float16)
|
||||
@@ -62,6 +61,9 @@ def single_sample_to_image(sample, approximation=None):
|
||||
sample = sample * (5 / abs(sample_min))
|
||||
"""
|
||||
if approximation == 2: # TAESD
|
||||
if sample.shape[-1] > 128 or sample.shape[-2] > 128:
|
||||
scale = 128 / max(sample.shape[-1], sample.shape[-2])
|
||||
sample = torch.nn.functional.interpolate(sample.unsqueeze(0), scale_factor=[scale, scale], mode='bilinear', align_corners=False)[0]
|
||||
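# Editorial note, not part of the diff: this keeps TAESD previews at roughly
# 1024px. Worked example: a 2048x2048 render has a 256x256 latent, so
# scale = 128 / 256 = 0.5 and the latent is resized to 128x128 before decoding,
# giving a ~1024px preview (8x latent-to-pixel factor) instead of decoding the
# full-size latent on every preview step.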
x_sample = sd_vae_taesd.decode(sample)
|
||||
x_sample = (1.0 + x_sample) / 2.0 # preview requires smaller range
|
||||
elif shared.sd_model_type == 'sc' and approximation != 3:
|
||||
|
||||
@@ -169,6 +169,9 @@ def decode(latents):
|
||||
if vae is None:
|
||||
return latents
|
||||
try:
|
||||
size = max(latents.shape[-1], latents.shape[-2])
|
||||
if size > 256:
|
||||
return latents
|
||||
with devices.inference_context():
|
||||
latents = latents.detach().clone().to(devices.device, dtype)
|
||||
if len(latents.shape) == 3:
|
||||
|
||||
@@ -873,6 +873,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), {

options_templates.update(options_section(('control', "Control Options"), {
    "control_max_units": OptionInfo(4, "Maximum number of units", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+    "control_tiles": OptionInfo("1x1, 1x2, 1x3, 1x4, 2x1, 2x1, 2x2, 2x3, 2x4, 3x1, 3x2, 3x3, 3x4, 4x1, 4x2, 4x3, 4x4", "Tiling options"),
    "control_move_processor": OptionInfo(False, "Processor move to CPU after use"),
    "control_unload_processor": OptionInfo(False, "Processor unload after use"),
}))
@@ -141,9 +141,9 @@ class State:
        if self.job == 'VAE': # avoid generating preview while vae is running
            return
        from modules.shared import opts, cmd_opts
-        if cmd_opts.lowvram or self.api:
+        if cmd_opts.lowvram or self.api or not opts.live_previews_enable or opts.show_progress_every_n_steps <= 0:
            return
-        if abs(self.sampling_step - self.current_image_sampling_step) >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps > 0:
+        if abs(self.sampling_step - self.current_image_sampling_step) >= opts.show_progress_every_n_steps:
            self.do_set_current_image()

    def do_set_current_image(self):
@@ -254,7 +254,7 @@ def create_ui(_blocks: gr.Blocks=None):
control_start = gr.Slider(label="CN Start", minimum=0.0, maximum=1.0, step=0.05, value=0, elem_id=f'control_unit-{i}-start')
control_end = gr.Slider(label="CN End", minimum=0.0, maximum=1.0, step=0.05, value=1.0, elem_id=f'control_unit-{i}-end')
control_mode = gr.Dropdown(label="CN Mode", choices=['default'], value='default', visible=False, elem_id=f'control_unit-{i}-mode')
-control_tile = gr.Dropdown(label="CN Tiles", choices=['1x1', '1x2', '1x3', '1x4', '2x1', '2x1', '2x2', '2x3', '2x4', '3x1', '3x2', '3x3', '3x4', '4x1', '4x2', '4x3', '4x4'], value='1x1', visible=False, elem_id=f'control_unit-{i}-tile')
+control_tile = gr.Dropdown(label="CN Tiles", choices=[x.strip() for x in shared.opts.control_tiles.split(',') if 'x' in x], value='1x1', visible=False, elem_id=f'control_unit-{i}-tile')
reset_btn = ui_components.ToolButton(value=ui_symbols.reset)
image_upload = gr.UploadButton(label=ui_symbols.upload, file_types=['image'], elem_classes=['form', 'gradio-button', 'tool'])
image_reuse= ui_components.ToolButton(value=ui_symbols.reuse)
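# Editorial note, not part of the diff: the hard-coded tile list is replaced by the
# new "control_tiles" setting, so the dropdown follows the user configuration.
# For example, with shared.opts.control_tiles = "1x1, 1x2, 2x2" the comprehension
# above yields choices == ['1x1', '1x2', '2x2']; entries without an 'x' are ignored.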
scripts/freescale.py (new file, 130 lines)
@@ -0,0 +1,130 @@
|
||||
import gradio as gr
|
||||
from modules import scripts, processing, shared, sd_models
|
||||
|
||||
|
||||
registered = False
|
||||
|
||||
|
||||
class Script(scripts.Script):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.orig_pipe = None
|
||||
self.orig_slice = None
|
||||
self.orig_tile = None
|
||||
self.is_img2img = False
|
||||
|
||||
def title(self):
|
||||
return 'FreeScale: Tuning-Free Scale Fusion'
|
||||
|
||||
def show(self, is_img2img):
|
||||
self.is_img2img = is_img2img
|
||||
return shared.native
|
||||
|
||||
def ui(self, _is_img2img): # ui elements
|
||||
with gr.Row():
|
||||
gr.HTML('<a href="https://github.com/ali-vilab/FreeScale">  FreeScale: Tuning-Free Scale Fusion</a><br>')
|
||||
with gr.Row():
|
||||
cosine_scale = gr.Slider(minimum=0.1, maximum=5.0, value=2.0, label='Cosine scale')
|
||||
override_sampler = gr.Checkbox(value=True, label='Override sampler')
|
||||
with gr.Row(visible=self.is_img2img):
|
||||
cosine_scale_bg = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, label='Cosine Background')
|
||||
dilate_tau = gr.Slider(minimum=1, maximum=100, value=35, label='Dilate tau')
|
||||
with gr.Row():
|
||||
s1_enable = gr.Checkbox(value=True, label='1st Stage', interactive=False)
|
||||
s1_scale = gr.Slider(minimum=1, maximum=8.0, value=1.0, label='Scale')
|
||||
s1_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
|
||||
with gr.Row():
|
||||
s2_enable = gr.Checkbox(value=True, label='2nd Stage')
|
||||
s2_scale = gr.Slider(minimum=1, maximum=8.0, value=2.0, label='Scale')
|
||||
s2_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
|
||||
with gr.Row():
|
||||
s3_enable = gr.Checkbox(value=False, label='3rd Stage')
|
||||
s3_scale = gr.Slider(minimum=1, maximum=8.0, value=3.0, label='Scale')
|
||||
s3_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
|
||||
with gr.Row():
|
||||
s4_enable = gr.Checkbox(value=False, label='4th Stage')
|
||||
s4_scale = gr.Slider(minimum=1, maximum=8.0, value=4.0, label='Scale')
|
||||
s4_restart = gr.Slider(minimum=0, maximum=1.0, value=0.75, label='Restart step')
|
||||
return [cosine_scale, override_sampler, cosine_scale_bg, dilate_tau, s1_enable, s1_scale, s1_restart, s2_enable, s2_scale, s2_restart, s3_enable, s3_scale, s3_restart, s4_enable, s4_scale, s4_restart]
|
||||
|
||||
def run(self, p: processing.StableDiffusionProcessing, cosine_scale, override_sampler, cosine_scale_bg, dilate_tau, s1_enable, s1_scale, s1_restart, s2_enable, s2_scale, s2_restart, s3_enable, s3_scale, s3_restart, s4_enable, s4_scale, s4_restart): # pylint: disable=arguments-differ
|
||||
supported_model_list = ['sdxl']
|
||||
if shared.sd_model_type not in supported_model_list:
|
||||
shared.log.warning(f'FreeScale: class={shared.sd_model.__class__.__name__} model={shared.sd_model_type} required={supported_model_list}')
|
||||
return None
|
||||
|
||||
if self.is_img2img:
|
||||
if p.init_images is None or len(p.init_images) == 0:
|
||||
shared.log.warning('FreeScale: missing input image')
|
||||
return None
|
||||
|
||||
from modules.freescale import StableDiffusionXLFreeScale, StableDiffusionXLFreeScaleImg2Img
|
||||
self.orig_pipe = shared.sd_model
|
||||
self.orig_slice = shared.opts.diffusers_vae_slicing
|
||||
self.orig_tile = shared.opts.diffusers_vae_tiling
|
||||
|
||||
def scale(x):
|
||||
if (p.width == 0 or p.height == 0) and p.init_images is not None:
|
||||
p.width, p.height = p.init_images[0].width, p.init_images[0].height
|
||||
resolution = [int(8 * p.width * x // 8), int(8 * p.height * x // 8)]
|
||||
return resolution
|
||||
|
||||
scales = []
|
||||
resolutions_list = []
|
||||
restart_steps = []
|
||||
if s1_enable:
|
||||
scales.append(s1_scale)
|
||||
resolutions_list.append(scale(s1_scale))
|
||||
restart_steps.append(int(p.steps * s1_restart))
|
||||
if s2_enable and s2_scale > s1_scale:
|
||||
scales.append(s2_scale)
|
||||
resolutions_list.append(scale(s2_scale))
|
||||
restart_steps.append(int(p.steps * s2_restart))
|
||||
if s3_enable and s3_scale > s2_scale:
|
||||
scales.append(s3_scale)
|
||||
resolutions_list.append(scale(s3_scale))
|
||||
restart_steps.append(int(p.steps * s3_restart))
|
||||
if s4_enable and s4_scale > s3_scale:
|
||||
scales.append(s4_scale)
|
||||
resolutions_list.append(scale(s4_scale))
|
||||
restart_steps.append(int(p.steps * s4_restart))
|
||||
|
||||
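# Editorial note, not part of the diff: a worked example of the values handed to the
# FreeScale pipeline. With p.width = p.height = 1024, p.steps = 20 and the default
# stages (1st stage scale 1.0, 2nd stage scale 2.0, both with restart 0.75):
#
#   resolutions_list = [[1024, 1024], [2048, 2048]]
#   restart_steps    = [15, 15]    # int(20 * 0.75); clamped to [1, steps-1] below
#
# so denoising runs at the base resolution first and each larger scale re-enters
# the schedule at its restart step.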
p.task_args['resolutions_list'] = resolutions_list
|
||||
p.task_args['cosine_scale'] = cosine_scale
|
||||
p.task_args['restart_steps'] = [min(max(1, step), p.steps-1) for step in restart_steps]
|
||||
if self.is_img2img:
|
||||
p.task_args['cosine_scale_bg'] = cosine_scale_bg
|
||||
p.task_args['dilate_tau'] = dilate_tau
|
||||
p.task_args['img_path'] = p.init_images[0]
|
||||
p.init_images = None
|
||||
if override_sampler:
|
||||
p.sampler_name = 'Euler a'
|
||||
|
||||
if p.width < 1024 or p.height < 1024:
|
||||
shared.log.error(f'FreeScale: width={p.width} height={p.height} minimum=1024')
|
||||
return None
|
||||
|
||||
if not self.is_img2img:
|
||||
shared.sd_model = sd_models.switch_pipe(StableDiffusionXLFreeScale, shared.sd_model)
|
||||
else:
|
||||
shared.sd_model = sd_models.switch_pipe(StableDiffusionXLFreeScaleImg2Img, shared.sd_model)
|
||||
shared.sd_model.enable_vae_slicing()
|
||||
shared.sd_model.enable_vae_tiling()
|
||||
|
||||
shared.log.info(f'FreeScale: mode={"txt" if not self.is_img2img else "img"} cosine={cosine_scale} bg={cosine_scale_bg} tau={dilate_tau} scales={scales} resolutions={resolutions_list} steps={restart_steps} sampler={p.sampler_name}')
|
||||
resolutions = ','.join([f'{x[0]}x{x[1]}' for x in resolutions_list])
|
||||
steps = ','.join([str(x) for x in restart_steps])
|
||||
p.extra_generation_params["FreeScale"] = f'cosine {cosine_scale} resolutions {resolutions} steps {steps}'
|
||||
|
||||
def after(self, p: processing.StableDiffusionProcessing, processed: processing.Processed, *args): # pylint: disable=arguments-differ, unused-argument
|
||||
if self.orig_pipe is None:
|
||||
return processed
|
||||
# restore pipeline
|
||||
if shared.sd_model_type == "sdxl":
|
||||
shared.sd_model = self.orig_pipe
|
||||
self.orig_pipe = None
|
||||
if not self.orig_slice:
|
||||
shared.sd_model.disable_vae_slicing()
|
||||
if not self.orig_tile:
|
||||
shared.sd_model.disable_vae_tiling()
|
||||
return processed