1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/images_resize.py
Vladimir Mandic f4b0656dbb reduce requirements
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-11-12 19:24:47 -05:00

163 lines
7.8 KiB
Python

from typing import Union
import sys
import time
import numpy as np
import torch
from PIL import Image
from modules import shared, upscaler
def resize_image(resize_mode: int, im: Union[Image.Image, torch.Tensor], width: int, height: int, upscaler_name: str=None, output_type: str='image', context: str=None):
    """Resize an image or latent tensor to the requested size.

    Args:
        resize_mode: resize strategy; `None` is treated as 0.
            0=none (copy), 1=fixed, 2=crop, 3=fill, 4=edge (fill + outpaint), 5=context-aware (seam carving).
            Any other value logs an error and returns a copy of the input.
        im: PIL image, numpy array or torch tensor. Tensors go straight through the fixed
            resize path regardless of `resize_mode`; arrays are converted to PIL first.
        width: target width.
        height: target height.
        upscaler_name: upscaler used when the target is larger than the source;
            falls back to `shared.opts.upscaler_for_img2img` when empty.
        output_type: `'np'` returns a numpy array, any other value returns the image object as-is.
        context: mode 5 only; must contain `forward` or `backward` (energy mode) and
            `add` or `remove` (seam direction), otherwise the image is returned unchanged.

    Returns:
        The resized image (PIL image or numpy array), the resized tensor for tensor input,
        or the unmodified input if it could not be converted to a PIL image.
    """
    upscaler_name = upscaler_name or shared.opts.upscaler_for_img2img

    def verify_image(image):
        """Best-effort conversion of a tensor or numpy array into a PIL image."""
        try:
            if isinstance(image, torch.Tensor):
                image = image.float().detach().cpu().numpy()
            if isinstance(image, np.ndarray):
                if np.issubdtype(image.dtype, np.floating):
                    # clip before the cast so out-of-range floats saturate instead of wrapping around
                    image = np.clip(255.0 * image, 0, 255).astype(np.uint8)
                image = Image.fromarray(image)
        except Exception as e:
            # deliberate best-effort: caller checks the returned type and bails out
            shared.log.error(f"Image verification failed: {e}")
        return image

    def latent(im, scale: float, selected_upscaler: upscaler.UpscalerData):
        """Run a latent upscaler; PIL input is round-tripped through the tiny VAE."""
        if isinstance(im, torch.Tensor):
            return selected_upscaler.scaler.upscale(im, scale, selected_upscaler.name)
        from modules.processing_vae import vae_encode, vae_decode
        latents = vae_encode(im, shared.sd_model, vae_type='Tiny') # TODO resize image: enable full VAE mode for resize-latent
        latents = selected_upscaler.scaler.upscale(latents, scale, selected_upscaler.name)
        return vae_decode(latents, shared.sd_model, output_type='pil', vae_type='Tiny')[0]

    def resize(im: Union[Image.Image, torch.Tensor], w, h):
        """Resize to exactly w x h, using the selected upscaler when enlarging."""
        w, h = int(w), int(h)
        if upscaler_name is None or upscaler_name == "None" or (hasattr(im, 'mode') and im.mode == 'L'):
            # NOTE(review): a torch.Tensor reaching this line would be called with a PIL-style
            # `resample` argument - confirm tensors never arrive here without an upscaler
            return im.resize((w, h), resample=Image.Resampling.LANCZOS) # force for mask
        if isinstance(im, torch.Tensor):
            # presumably latents are 1/8 of pixel resolution - TODO confirm
            scale = max(w // 8 / im.shape[-1], h // 8 / im.shape[-2])
        else:
            scale = max(w / im.width, h / im.height)
        if scale > 1.0:
            wanted = upscaler_name.lower().replace('-', ' ')
            upscalers = [x for x in shared.sd_upscalers if x.name.lower().replace('-', ' ') == wanted]
            if upscalers:
                selected_upscaler: upscaler.UpscalerData = upscalers[0]
                if selected_upscaler.name.lower().startswith('latent'):
                    im = latent(im, scale, selected_upscaler)
                else:
                    im = selected_upscaler.scaler.upscale(im, scale, selected_upscaler.name)
            else:
                # no upscaler matched, so there is no selected upscaler to report;
                # PIL images still reach the target size via the lanczos resize below
                fallback = 'Lanczos' if isinstance(im, Image.Image) else 'None'
                shared.log.warning(f"Resize upscaler: invalid={upscaler_name} fallback={fallback}")
                shared.log.debug(f"Resize upscaler: available={[u.name for u in shared.sd_upscalers]}")
        if isinstance(im, Image.Image) and (im.width != w or im.height != h): # probably downsample after upscaler created larger image
            im = im.resize((w, h), resample=Image.Resampling.LANCZOS)
        return im

    def crop(im: Image.Image):
        """Scale to cover the target while keeping aspect ratio, then center-crop the overflow."""
        ratio = width / height
        src_ratio = im.width / im.height
        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width
        resized = resize(im, src_w, src_h)
        res = Image.new(im.mode, (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
        return res

    def fill(im: Image.Image, color=None):
        """Scale to fit inside the target while keeping aspect ratio, centered on a solid background."""
        if color is None: # explicit check so that a caller-supplied 0 (black) is honored
            color = shared.opts.image_background
        ratio = min(width / im.width, height / im.height)
        im = resize(im, int(im.width * ratio), int(im.height * ratio))
        res = Image.new(im.mode, (width, height), color=color)
        res.paste(im, box=((width - im.width)//2, (height - im.height)//2))
        return res

    def context_aware(im: Image.Image, width, height, context):
        """Seam-carving resize; returns the input unchanged when `context` names no known mode."""
        from installer import install
        install('seam-carving')
        width, height = int(width), int(height)
        import seam_carving # https://github.com/li-plus/seam-carving
        mode = (context or '').lower() # context is optional on the public signature
        if 'forward' in mode:
            energy_mode = "forward"
        elif 'backward' in mode:
            energy_mode = "backward"
        else:
            return im
        if 'add' in mode:
            # fit inside the target first, seams are then added to reach the target size
            src_ratio = min(width / im.width, height / im.height)
            src_w = int(im.width * src_ratio)
            src_h = int(im.height * src_ratio)
        elif 'remove' in mode:
            # cover the target first, seams are then removed to reach the target size
            ratio = width / height
            src_ratio = im.width / im.height
            src_w = width if ratio > src_ratio else im.width * height // im.height
            src_h = height if ratio <= src_ratio else im.height * width // im.width
        else:
            return im
        src_image = resize(im, src_w, src_h)
        np_image = seam_carving.resize(
            src_image, # source image (rgb or gray)
            size=(width, height), # target size
            energy_mode=energy_mode, # choose from {backward, forward}
            order="width-first", # choose from {width-first, height-first}
            keep_mask=None, # object mask to protect from removal
        )
        return Image.fromarray(np_image)

    t0 = time.time()
    if resize_mode is None:
        resize_mode = 0
    if isinstance(im, torch.Tensor): # latent resize only supports fixed mode
        return resize(im, width, height)
    im = verify_image(im)
    if not isinstance(im, Image.Image):
        shared.log.error(f'Image resize: image={type(im)} invalid type')
        return im
    if (resize_mode == 0) or ((im.width == width) and (im.height == height)) or (width == 0 and height == 0): # none
        res = im.copy()
    elif resize_mode == 1: # fixed
        res = resize(im, width, height)
    elif resize_mode == 2: # crop
        res = crop(im)
    elif resize_mode == 3: # fill
        res = fill(im)
    elif resize_mode == 4: # edge
        from modules import masking
        res = fill(im, color=0)
        res, _mask = masking.outpaint(res)
    elif resize_mode == 5: # context-aware
        res = context_aware(im, width, height, context)
    else:
        res = im.copy()
        shared.log.error(f'Invalid resize mode: {resize_mode}')
    t1 = time.time()
    if im.width != width or im.height != height:
        # diagnostics only: nothing here may raise and discard an already computed result
        try:
            fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
        except ValueError: # call stack is shallower than two frames
            fn = 'unknown'
        try:
            mode_name = shared.resize_modes[resize_mode]
        except (IndexError, KeyError, TypeError): # invalid mode was already reported above
            mode_name = resize_mode
        shared.log.debug(f'Image resize: source={im.width}:{im.height} target={width}:{height} mode="{mode_name}" upscaler="{upscaler_name}" type={output_type} time={t1-t0:.2f} fn={fn}')
    return np.array(res) if output_type == 'np' else res