# Mirror of https://github.com/vladmandic/sdnext.git
# Synced 2026-01-27 15:02:48 +03:00
# 163 lines, 7.8 KiB, Python
from typing import Union
|
|
import sys
|
|
import time
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from modules import shared, upscaler
|
|
|
|
|
|
def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None):
    """
    Resize an image or a tensor to `width` x `height` using the selected resize mode.

    Args:
        resize_mode: 0 or None = none (copy), 1 = fixed, 2 = crop, 3 = fill,
            4 = edge (fill with color 0, then `masking.outpaint`), 5 = context-aware (seam carving).
            Any other value logs an error and returns a copy of the input.
        im: PIL image, numpy array or torch tensor. Tensors are always resized using
            the fixed mode and returned directly; numpy arrays are converted to PIL first.
        width: target width.
        height: target height.
        upscaler_name: name of the upscaler to use when the image has to grow;
            defaults to `shared.opts.upscaler_for_img2img`.
        output_type: 'np' returns a numpy array, anything else returns the PIL image.
        context: context-aware mode selector; must contain 'forward' or 'backward'
            and 'add' or 'remove', otherwise the image is returned unchanged.

    Returns:
        The resized PIL image (or numpy array when `output_type == 'np'`). Tensor input
        returns whatever the fixed resize produced; input that cannot be converted to
        a PIL image is returned as-is after logging an error.
    """
    upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img

    def verify_image(image):
        """Best-effort conversion of a tensor or numpy array to a PIL image; other inputs pass through."""
        try:
            if isinstance(image, torch.Tensor):
                image = image.float().detach().cpu().numpy()
            if isinstance(image, np.ndarray):
                if np.issubdtype(image.dtype, np.floating):
                    # floats are scaled by 255 (presumably normalized to 0..1 - confirm with callers);
                    # clip before the cast so out-of-range values saturate instead of wrapping around
                    image = np.clip(255.0 * image, 0, 255).astype(np.uint8)
                image = Image.fromarray(image)
        except Exception as e:
            # deliberate best-effort: the caller checks the returned type and reports invalid input
            shared.log.error(f"Image verification failed: {e}")
        return image

    def latent(im, scale: float, selected_upscaler: upscaler.UpscalerData):
        """Upscale with a latent upscaler: tensors directly, images via a VAE encode/decode round trip."""
        if isinstance(im, torch.Tensor):
            return selected_upscaler.scaler.upscale(im, scale, selected_upscaler.name)
        from modules.processing_vae import vae_encode, vae_decode
        latents = vae_encode(im, shared.sd_model, vae_type='Tiny') # TODO resize image: enable full VAE mode for resize-latent
        latents = selected_upscaler.scaler.upscale(latents, scale, selected_upscaler.name)
        return vae_decode(latents, shared.sd_model, output_type='pil', vae_type='Tiny')[0]

    def resize(im: Union[Image.Image, torch.Tensor], w, h):
        """Resize to exactly `w` x `h`, running the selected upscaler first when the image has to grow."""
        w, h = int(w), int(h)
        if upscaler_name is None or upscaler_name == "None" or (hasattr(im, 'mode') and im.mode == 'L'):
            # NOTE(review): this call uses the PIL signature; a torch tensor reaching this branch
            # (upscaler "None") looks unsupported - confirm with callers
            return im.resize((w, h), resample=Image.Resampling.LANCZOS) # force for mask
        if isinstance(im, torch.Tensor):
            # tensor dims are compared against target // 8 (presumably the latent-space factor - confirm)
            scale = max(w // 8 / im.shape[-1], h // 8 / im.shape[-2])
        else:
            scale = max(w / im.width, h / im.height)
        if scale > 1.0:
            wanted = upscaler_name.lower().replace('-', ' ')
            upscalers = [x for x in shared.sd_upscalers if x.name.lower().replace('-', ' ') == wanted]
            if upscalers:
                selected_upscaler: upscaler.UpscalerData = upscalers[0]
                if selected_upscaler.name.lower().startswith('latent'):
                    im = latent(im, scale, selected_upscaler)
                else:
                    im = selected_upscaler.scaler.upscale(im, scale, selected_upscaler.name)
            else:
                # no upscaler matched, so there is no selected upscaler to name here;
                # images fall through to the Lanczos resize below, tensors are left untouched
                fallback = 'Lanczos' if isinstance(im, Image.Image) else 'None'
                shared.log.warning(f"Resize upscaler: invalid={upscaler_name} fallback={fallback}")
                shared.log.debug(f"Resize upscaler: available={[u.name for u in shared.sd_upscalers]}")
        if isinstance(im, Image.Image) and (im.width != w or im.height != h): # probably downsample after upscaler created larger image
            im = im.resize((w, h), resample=Image.Resampling.LANCZOS)
        return im

    def crop(im: Image.Image):
        """Scale the image to cover the target area, then center it on a target-sized canvas (overflow is cut off)."""
        # NOTE(review): divides by `height` and `im.height`; a zero value raises ZeroDivisionError
        ratio = width / height
        src_ratio = im.width / im.height
        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width
        resized = resize(im, src_w, src_h)
        res = Image.new(im.mode, (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
        return res

    def fill(im: Image.Image, color=None):
        """Scale the image to fit inside the target area and center it on a canvas filled with `color`."""
        # only substitute the configured background when no color was given:
        # an explicit falsy color such as 0 (used by edge mode) must be honored
        if color is None:
            color = shared.opts.image_background
        ratio = min(width / im.width, height / im.height)
        im = resize(im, int(im.width * ratio), int(im.height * ratio))
        res = Image.new(im.mode, (width, height), color=color)
        res.paste(im, box=((width - im.width)//2, (height - im.height)//2))
        return res

    def context_aware(im: Image.Image, width, height, context):
        """Seam-carving resize; returns the image unchanged when `context` does not select a valid mode."""
        from installer import install
        install('seam-carving')
        width, height = int(width), int(height)
        import seam_carving # https://github.com/li-plus/seam-carving
        mode = (context or '').lower() # context defaults to None: treat it as "no mode selected"
        if 'forward' in mode:
            energy_mode = "forward"
        elif 'backward' in mode:
            energy_mode = "backward"
        else:
            return im
        if 'add' in mode:
            # pre-scale so the image fits inside the target
            src_ratio = min(width / im.width, height / im.height)
            src_w = int(im.width * src_ratio)
            src_h = int(im.height * src_ratio)
            src_image = resize(im, src_w, src_h)
        elif 'remove' in mode:
            # pre-scale so the image covers the target
            ratio = width / height
            src_ratio = im.width / im.height
            src_w = width if ratio > src_ratio else im.width * height // im.height
            src_h = height if ratio <= src_ratio else im.height * width // im.width
            src_image = resize(im, src_w, src_h)
        else:
            return im
        np_image = seam_carving.resize(
            src_image, # source image (rgb or gray)
            size=(width, height), # target size
            energy_mode=energy_mode, # choose from {backward, forward}
            order="width-first", # choose from {width-first, height-first}
            keep_mask=None, # object mask to protect from removal
        )
        return Image.fromarray(np_image)

    t0 = time.time()
    if resize_mode is None:
        resize_mode = 0
    if isinstance(im, torch.Tensor): # latent resize only supports fixed mode
        return resize(im, width, height)
    im = verify_image(im)
    if not isinstance(im, Image.Image):
        shared.log.error(f'Image resize: image={type(im)} invalid type')
        return im
    if (resize_mode == 0) or ((im.width == width) and (im.height == height)) or (width == 0 and height == 0): # none
        res = im.copy()
    elif resize_mode == 1: # fixed
        res = resize(im, width, height)
    elif resize_mode == 2: # crop
        res = crop(im)
    elif resize_mode == 3: # fill
        res = fill(im)
    elif resize_mode == 4: # edge
        from modules import masking
        res = fill(im, color=0)
        res, _mask = masking.outpaint(res)
    elif resize_mode == 5: # context-aware
        res = context_aware(im, width, height, context)
    else:
        res = im.copy()
        shared.log.error(f'Invalid resize mode: {resize_mode}')
    t1 = time.time()
    if im.width != width or im.height != height:
        try:
            fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
        except ValueError: # call stack is shallower than two caller frames
            fn = 'unknown'
        try:
            mode_name = shared.resize_modes[resize_mode]
        except (IndexError, KeyError, TypeError): # invalid mode was already reported above; logging must not raise
            mode_name = resize_mode
        shared.log.debug(f'Image resize: source={im.width}:{im.height} target={width}:{height} mode="{mode_name}" upscaler="{upscaler_name}" type={output_type} time={t1-t0:.2f} fn={fn}')
    return np.array(res) if output_type == 'np' else res
|