mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
220 lines
9.7 KiB
Python
220 lines
9.7 KiB
Python
import math
|
|
from collections import namedtuple
|
|
import numpy as np
|
|
from PIL import Image, ImageFont, ImageDraw
|
|
from modules import shared, script_callbacks
|
|
|
|
|
|
# Tiling description produced by split_grid and consumed by combine_grid:
# ``tiles`` holds one [y, tile_h, [[x, tile_w, tile_image], ...]] entry per tile row,
# the other fields record the tile size, the full image size and the overlap in pixels.
Grid = namedtuple("Grid", "tiles tile_w tile_h image_w image_h overlap")
|
|
|
|
|
|
def check_grid_size(imgs):
    """Return True when the total pixel count of ``imgs`` is within the configured limit.

    ``imgs`` is a list whose entries are images, ``None`` (ignored) or one level of nested
    lists of images/``None``. A missing or empty list returns False. The limit is
    ``shared.opts.img_max_size_mp`` megapixels; a warning is logged when it is exceeded.
    """
    if imgs is None or len(imgs) == 0:
        return False
    pixels = 0
    for img in imgs:
        group = img if isinstance(img, list) else [img]
        pixels += sum(im.width * im.height for im in group if im is not None)
    # Compare the exact value: rounding first would let totals up to ~0.5 MP over the limit pass.
    mp = pixels / 1000000
    ok = mp <= shared.opts.img_max_size_mp
    if not ok:
        shared.log.warning(f'Maximum image size exceeded: size={mp:.2f} maximum={shared.opts.img_max_size_mp} MPixels')
    return ok
|
|
|
|
|
|
def get_grid_size(imgs, batch_size=1, rows=None, cols=None):
    """Work out the ``(rows, cols)`` layout for a grid of ``imgs``.

    Explicit ``rows``/``cols`` take priority and are clamped to the number of images;
    non-positive values are treated as "not given". When both are omitted the layout
    comes from ``shared.opts.n_rows`` / ``shared.opts.n_cols``:
    a positive value fixes that dimension, ``0`` means "use ``batch_size``", and when
    both options are negative the row count is the largest divisor of the image count
    not exceeding its square root.

    Returns:
        tuple: ``(rows, cols)``. Every divisor used is at least 1, so an empty ``imgs``
        does not raise ZeroDivisionError.
    """
    count = len(imgs)
    limit = max(count, 1)  # clamping to 0 would later be used as a divisor
    if rows is not None and rows <= 0:
        rows = None
    if cols is not None and cols <= 0:
        cols = None
    if rows and rows > limit:
        rows = limit
    if cols and cols > limit:
        cols = limit
    if rows is None and cols is None:
        if shared.opts.n_rows > 0:
            rows = shared.opts.n_rows
            cols = math.ceil(count / rows)
        elif shared.opts.n_rows == 0:
            rows = max(batch_size, 1)
            cols = math.ceil(count / rows)
        elif shared.opts.n_cols > 0:
            cols = shared.opts.n_cols
            rows = math.ceil(count / cols)
        elif shared.opts.n_cols == 0:
            cols = max(batch_size, 1)
            rows = math.ceil(count / cols)
        else:
            # most square layout whose row count divides the image count evenly
            rows = max(math.floor(math.sqrt(count)), 1)
            while count % rows != 0:
                rows -= 1
            cols = math.ceil(count / rows)
    elif cols is None:
        cols = math.ceil(count / rows)
    elif rows is None:
        rows = math.ceil(count / cols)
    return rows, cols
|
|
|
|
|
|
def image_grid(imgs, batch_size:int=1, rows:int=None, cols:int=None):
    """Paste ``imgs`` into one RGB grid image, filled row by row.

    Every cell is as large as the largest image; ``None`` entries leave their cell empty
    (filled with ``shared.opts.grid_background``). Registered image-grid callbacks run
    first, and the grid is built from ``params`` as the callbacks leave it.

    Returns:
        The grid image, or None when there is nothing to paste.
    """
    # Bail out before get_grid_size, which cannot size a None or empty list.
    if not imgs:
        return None
    rows, cols = get_grid_size(imgs, batch_size, rows=rows, cols=cols)
    params = script_callbacks.ImageGridLoopParams(imgs, cols, rows)
    script_callbacks.image_grid_callback(params)
    # Size the cells from the images that will actually be pasted (i.e. after the callbacks).
    images = [img for img in params.imgs if img is not None]
    if len(images) == 0:
        return None
    w = max(img.width for img in images)
    h = max(img.height for img in images)
    grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color=shared.opts.grid_background)
    for i, img in enumerate(params.imgs):
        if img is not None:
            grid.paste(img, box=(i % params.cols * w, i // params.cols * h))
    return grid
|
|
|
|
|
|
def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    """Cut ``image`` into a grid of overlapping ``tile_w`` x ``tile_h`` crops.

    Tiles are spread evenly: the first row/column starts at 0 and the last one is pushed
    back so it ends exactly at the image edge. Neighbouring tiles overlap by about
    ``overlap`` pixels or more. For an image smaller than a tile the offset becomes
    negative and the crop extends past the image border.

    Returns:
        Grid: ``tiles`` holds one ``[y, tile_h, row]`` list per tile row, where ``row`` is a
        list of ``[x, tile_w, tile_image]`` entries.
    """
    img_w, img_h = image.width, image.height
    n_cols = math.ceil((img_w - overlap) / (tile_w - overlap))
    n_rows = math.ceil((img_h - overlap) / (tile_h - overlap))
    step_x = 0 if n_cols <= 1 else (img_w - tile_w) / (n_cols - 1)
    step_y = 0 if n_rows <= 1 else (img_h - tile_h) / (n_rows - 1)
    result = Grid([], tile_w, tile_h, img_w, img_h, overlap)
    for r in range(n_rows):
        top = int(r * step_y)
        top = top if top + tile_h < img_h else img_h - tile_h
        cells = []
        for c in range(n_cols):
            left = int(c * step_x)
            left = left if left + tile_w < img_w else img_w - tile_w
            cells.append([left, tile_w, image.crop((left, top, left + tile_w, top + tile_h))])
        result.tiles.append([top, tile_h, cells])
    return result
|
|
|
|
|
|
def combine_grid(grid):
    """Stitch the tiles of ``grid`` (as produced by ``split_grid``) back into one RGB image.

    Each tile row is assembled into a full-width strip, then the strips are stacked.
    Every tile or strip after the first blends its leading ``grid.overlap`` pixels over
    the previous one. The blend uses a linear ramp mask running from 0 up to
    ``255 * (overlap - 1) / overlap``.
    """
    overlap = grid.overlap

    def gradient_mask(ramp):
        # scale the 0..overlap-1 ramp into the 8-bit range and wrap it as an 'L' mask
        return Image.fromarray((ramp * 255 / overlap).astype(np.uint8), 'L')

    ramp = np.arange(overlap, dtype=np.float32)
    mask_w = gradient_mask(ramp.reshape((1, overlap)).repeat(grid.tile_h, axis=0))  # horizontal blend
    mask_h = gradient_mask(ramp.reshape((overlap, 1)).repeat(grid.image_w, axis=1))  # vertical blend
    combined = Image.new("RGB", (grid.image_w, grid.image_h))
    for row_y, row_h, row_tiles in grid.tiles:
        strip = Image.new("RGB", (grid.image_w, row_h))
        for tile_x, tile_w, tile in row_tiles:
            if tile_x == 0:
                strip.paste(tile, (0, 0))
            else:
                strip.paste(tile.crop((0, 0, overlap, row_h)), (tile_x, 0), mask=mask_w)
                strip.paste(tile.crop((overlap, 0, tile_w, row_h)), (tile_x + overlap, 0))
        if row_y == 0:
            combined.paste(strip, (0, 0))
        else:
            combined.paste(strip.crop((0, 0, strip.width, overlap)), (0, row_y), mask=mask_h)
            combined.paste(strip.crop((0, overlap, strip.width, row_h)), (0, row_y + overlap))
    return combined
|
|
|
|
|
|
class GridAnnotation:
    """One label line drawn next to an image grid.

    ``text`` is always stored as a string. ``is_active`` chooses between the normal
    rendering and the greyed-out, struck-through one used for inactive labels. ``size``
    stays None until the text has been measured.
    """

    def __init__(self, text='', is_active=True):
        self.size = None  # (width, height) in pixels once measured
        self.is_active = is_active
        self.text = str(text)
|
|
|
|
|
|
def get_font(fontsize):
    """Load the TrueType font from ``shared.opts.font`` at ``fontsize`` points.

    Uses ``javascript/notosans-nerdfont-regular.ttf`` when no font is configured, and
    retries with that file if loading the configured font fails for any reason.
    """
    default_font = "javascript/notosans-nerdfont-regular.ttf"
    try:
        return ImageFont.truetype(shared.opts.font or default_font, fontsize)
    except Exception:
        return ImageFont.truetype(default_font, fontsize)
|
|
|
|
|
|
def draw_grid_annotations(im, width, height, x_texts, y_texts, margin=0, title=None):
    """Return a copy of grid ``im`` with column labels, row labels and an optional title.

    Args:
        im: grid image made of ``len(x_texts)`` columns by ``len(y_texts)`` rows of cells.
        width: width of one cell in pixels.
        height: height of one cell in pixels.
        x_texts: one list of ``GridAnnotation`` per column, drawn above the grid.
        y_texts: one list of ``GridAnnotation`` per row, drawn left of the grid.
        margin: gap in pixels inserted between neighbouring cells.
        title: optional list of ``GridAnnotation`` drawn above the column labels.

    Note:
        The lists in ``x_texts``, ``y_texts`` and ``title`` are rewritten in place with
        their word-wrapped lines, each carrying ``size`` and ``allowed_width``.

    Returns:
        A new RGB image with the re-spaced cells and the labels.
    """
    def wrap(drawing, text, font, line_length):
        # Greedy word wrap: keep appending words to the current line while it still fits.
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            # An empty current line always takes the word, so an over-long first word
            # does not leave a blank line in front of it.
            if not lines[-1] or drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines

    def draw_texts(drawing: ImageDraw.ImageDraw, draw_x, draw_y, lines, initial_fnt, initial_fontsize):
        # Draw `lines` centred horizontally on draw_x, stacked downward from draw_y.
        for line in lines:
            font = initial_fnt
            fontsize = initial_fontsize
            # Shrink the font until the line fits its allowed width, but never below 1:
            # a 0-size TrueType font cannot be loaded, so get_font(0) would raise.
            while drawing.multiline_textbbox((0, 0), text=line.text, font=font)[2] > line.allowed_width and fontsize > 1:
                fontsize -= 1
                font = get_font(fontsize)
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=font, fill=shared.opts.font_color if line.is_active else color_inactive, anchor="mm", align="center")
            if not line.is_active:
                # strike through inactive labels
                drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
            draw_y += line.size[1] + line_spacing

    def block_height(lines):
        # Height of a stacked label block: every line plus the spacing between consecutive
        # lines, matching how draw_texts advances draw_y.
        return sum(line.size[1] + line_spacing for line in lines) - line_spacing

    fontsize = max((width + height) // 25, 1)  # very small cells would otherwise ask for a 0-size font
    line_spacing = fontsize // 2
    font = get_font(fontsize)
    color_inactive = (127, 127, 127)
    # Row labels get a left gutter only when at least one of them has text.
    pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in y_texts]) == 0 else width * 3 // 4
    cols = len(x_texts)
    rows = len(y_texts)
    # Scratch canvas used only to measure text.
    calc_img = Image.new("RGB", (1, 1), shared.opts.grid_background)
    calc_d = ImageDraw.Draw(calc_img)
    title_texts = [title] if title else [[GridAnnotation()]]
    # Replace every label with its wrapped lines and record each line's measured size and width budget.
    for texts, allowed_width in zip(x_texts + y_texts + title_texts, [width] * len(x_texts) + [pad_left] * len(y_texts) + [(width+margin)*cols]):
        items = [] + texts
        texts.clear()
        for line in items:
            wrapped = wrap(calc_d, line.text, font, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]
        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=font)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
            line.allowed_width = allowed_width
    # Same height formula for columns, rows and title so multi-line labels are centred consistently.
    hor_text_heights = [block_height(lines) for lines in x_texts]
    ver_text_heights = [block_height(lines) for lines in y_texts]
    pad_top = 0 if sum(hor_text_heights) == 0 else max(hor_text_heights) + line_spacing * 2
    title_pad = 0
    if title:
        title_text_heights = [block_height(lines) for lines in title_texts]
        title_pad = 0 if sum(title_text_heights) == 0 else max(title_text_heights) + line_spacing * 2
    result = Image.new("RGB", (im.width + pad_left + margin * (cols-1), im.height + pad_top + title_pad + margin * (rows-1)), shared.opts.grid_background)
    # Copy every cell to its new position, shifted by the label padding and separated by `margin`.
    for row in range(rows):
        for col in range(cols):
            cell = im.crop((width * col, height * row, width * (col+1), height * (row+1)))
            result.paste(cell, (pad_left + (width + margin) * col, pad_top + title_pad + (height + margin) * row))
    d = ImageDraw.Draw(result)
    if title:
        x = pad_left + ((width+margin)*cols) / 2
        y = title_pad / 2 - title_text_heights[0] / 2
        draw_texts(d, x, y, title_texts[0], font, fontsize)
    for col in range(cols):
        x = pad_left + (width + margin) * col + width / 2
        y = (pad_top / 2 - hor_text_heights[col] / 2) + title_pad
        draw_texts(d, x, y, x_texts[col], font, fontsize)
    for row in range(rows):
        x = pad_left / 2
        y = (pad_top + (height + margin) * row + height / 2 - ver_text_heights[row] / 2) + title_pad
        draw_texts(d, x, y, y_texts[row], font, fontsize)
    return result
|
|
|
|
|
|
def draw_prompt_matrix(im, width, height, all_prompts, margin=0):
    """Label a prompt-matrix grid with every on/off combination of the optional prompt parts.

    ``all_prompts[0]`` is skipped. The first half (rounded up) of the remaining parts
    labels the columns and the rest labels the rows. For label index ``n``, part ``i``
    is active exactly when bit ``i`` of ``n`` is set.
    """
    extra = all_prompts[1:]
    split_at = math.ceil(len(extra) / 2)

    def combinations(parts):
        # one label list per bitmask over `parts`
        return [
            [GridAnnotation(part, is_active=bool(mask >> i & 1)) for i, part in enumerate(parts)]
            for mask in range(2 ** len(parts))
        ]

    columns = combinations(extra[:split_at])
    rows = combinations(extra[split_at:])
    return draw_grid_annotations(im, width, height, columns, rows, margin)
|