1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/scripts/lbm/tiler.py
Vladimir Mandic 2b9056179d add lbm background replace with relightining
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-07-04 15:33:16 -04:00

378 lines
12 KiB
Python

import logging
import math
from copy import deepcopy
from typing import List, Tuple
import torch
import torch.nn.functional as F
TILING_METHODS = ["average", "gaussian", "linear"]


class Tiler:
    """Split an image tensor into overlapping tiles and merge processed tiles back.

    The class is stateful: :meth:`get_tiles` records the output geometry
    (``output_shape``, ``output_tile_size``, ``output_overlap_size``, all derived
    with ``scale``) on the instance, and :meth:`merge_tiles` reads it back.
    ``get_tiles`` must therefore be called before ``merge_tiles``.
    """

    def get_tiles(
        self,
        input: torch.Tensor,
        tile_size: tuple,
        overlap_size: tuple,
        scale: int = 1,
        out_channels: int = 3,
    ) -> List[List[torch.Tensor]]:
        """Split ``input`` into a row-major grid of overlapping tiles.

        A new tile starts every ``tile_size - overlap`` pixels along each axis.
        Tiles that reach the bottom or right border are truncated by slicing, so
        they may be smaller than ``tile_size``. If an axis is not larger than the
        tile size, the overlap along that axis is set to 0.

        Args:
            input (torch.Tensor): input array of shape (batch_size, channels, height, width)
            tile_size (tuple): (height, width) of a tile
            overlap_size (tuple): (height, width) overlap between neighbouring tiles
            scale (int): scaling factor of the processed tiles w.r.t. the input
                tiles; only used to precompute the geometry for ``merge_tiles``
            out_channels (int): number of channels of the processed tiles

        Returns:
            List[List[torch.Tensor]]: ``tiles[row][col]``. Each tile is an
            independent copy of its input region.

        Raises:
            AssertionError: if an overlap exceeds the tile size.
            ValueError: if the step between tiles is not positive (e.g. the
                overlap equals the tile size on an axis larger than the tile).
        """
        assert (
            overlap_size[0] <= tile_size[0]
        ), f"Overlap size {overlap_size} must not exceed tile size {tile_size}"
        assert (
            overlap_size[1] <= tile_size[1]
        ), f"Overlap size {overlap_size} must not exceed tile size {tile_size}"
        B, C, H, W = input.shape
        tile_size_H, tile_size_W = tile_size
        # an axis that fits in a single tile needs no overlap
        overlap_H = overlap_size[0] if H > tile_size_H else 0
        overlap_W = overlap_size[1] if W > tile_size_W else 0
        step_H = tile_size_H - overlap_H
        step_W = tile_size_W - overlap_W
        if step_H <= 0 or step_W <= 0:
            # would otherwise surface as an opaque "range() arg 3 must not be zero"
            raise ValueError(
                f"Tile size {tile_size} minus overlap {(overlap_H, overlap_W)} must be positive"
            )
        self.output_overlap_size = (int(overlap_H * scale), int(overlap_W * scale))
        self.tile_size = tile_size
        self.output_tile_size = (int(tile_size_H * scale), int(tile_size_W * scale))
        self.output_shape = (B, out_channels, int(H * scale), int(W * scale))
        logging.debug("(Tiler) Input shape: %s", (B, C, H, W))
        logging.debug("(Tiler) Output shape: %s", self.output_shape)
        logging.debug("(Tiler) Tile size: %s", (tile_size_H, tile_size_W))
        logging.debug("(Tiler) Overlap size: %s", (overlap_H, overlap_W))
        # clone() copies only the tile region (and keeps autograd working);
        # deepcopy of a slice duplicates the slice's entire underlying storage
        return [
            [
                input[:, :, i : i + tile_size_H, j : j + tile_size_W].clone()
                for j in range(0, W, step_W)
            ]
            for i in range(0, H, step_H)
        ]

    def merge_tiles(
        self, tiles: List[List[torch.Tensor]], tiling_method: str = "gaussian"
    ) -> torch.Tensor:
        """Merge processed tiles into one image using the chosen blending method.

        Requires a prior call to :meth:`get_tiles`, which sets the output geometry.

        Args:
            tiles (List[List[torch.Tensor]]): processed tiles, laid out as returned by ``get_tiles``
            tiling_method (str): tiling method. Can be "average", "gaussian" or "linear"

        Returns:
            torch.Tensor: output image

        Raises:
            ValueError: if ``tiling_method`` is not one of ``TILING_METHODS``.
        """
        if tiling_method == "average":
            return self._average_merge_tiles(tiles)
        elif tiling_method == "gaussian":
            return self._gaussian_merge_tiles(tiles)
        elif tiling_method == "linear":
            return self._linear_merge_tiles(tiles)
        else:
            raise ValueError(
                f"Unknown tiling method {tiling_method}. Available methods are {TILING_METHODS}"
            )

    @staticmethod
    def _tiles_device(tiles: List[List[torch.Tensor]]) -> torch.device:
        """Return the device of the first tile, or CPU when there are no tiles."""
        for row in tiles:
            for tile in row:
                return tile.device
        return torch.device("cpu")

    def _output_slots(self):
        """Yield ``(row, col, index)`` for every tile slot of the output canvas.

        Slots start every ``output_tile_size - output_overlap_size`` pixels, which
        is the same rule ``get_tiles`` uses, applied to the scaled sizes. ``index``
        is a tuple of slices covering the slot, clipped by slicing at the border.
        """
        _, _, output_H, output_W = self.output_shape
        tile_H, tile_W = self.output_tile_size
        overlap_H, overlap_W = self.output_overlap_size
        for id_i, top in enumerate(range(0, output_H, tile_H - overlap_H)):
            for id_j, left in enumerate(range(0, output_W, tile_W - overlap_W)):
                yield id_i, id_j, (
                    slice(None),
                    slice(None),
                    slice(top, top + tile_H),
                    slice(left, left + tile_W),
                )

    def _average_merge_tiles(self, tiles: List[List[torch.Tensor]]) -> torch.Tensor:
        """Merge tiles by averaging the overlapping regions.

        Args:
            tiles (List[List[torch.Tensor]]): processed tiles

        Returns:
            torch.Tensor: output image (default dtype, on the tiles' device)
        """
        device = self._tiles_device(tiles)
        # allocate on the tiles' device so GPU tiles can be accumulated
        output = torch.zeros(self.output_shape, device=device)
        # number of tiles that contributed to each output pixel
        weights = torch.zeros(self.output_shape, device=device)
        for id_i, id_j, region in self._output_slots():
            output[region] += tiles[id_i][id_j]
            weights[region] += 1
        return output / weights

    def _gaussian_weights(
        self, tile_width: int, tile_height: int, nbatches: int, channels: int
    ) -> torch.Tensor:
        """Generate a gaussian mask of weights for tile contributions.

        The mask is the outer product of two 1-D gaussians, each centred on the
        middle index of its axis, so it is symmetric along both axes.

        Args:
            tile_width (int): width of the tile
            tile_height (int): height of the tile
            nbatches (int): number of batches
            channels (int): number of channels

        Returns:
            torch.Tensor: float64 CPU weights of shape
            (nbatches, channels, tile_height, tile_width)
        """
        var = 0.01
        norm = math.sqrt(2 * math.pi * var)

        def profile(length: int) -> torch.Tensor:
            # indices run 0..length-1, so the centre is (length - 1) / 2
            coords = torch.arange(length, dtype=torch.float64)
            midpoint = (length - 1) / 2
            return (
                torch.exp(-((coords - midpoint) ** 2) / (length * length) / (2 * var))
                / norm
            )

        weights = torch.outer(profile(tile_height), profile(tile_width))
        return torch.tile(weights, (nbatches, channels, 1, 1))

    def _gaussian_merge_tiles(self, tiles: List[List[torch.Tensor]]) -> torch.Tensor:
        """Merge tiles with a gaussian-weighted average of the overlapping regions.

        Args:
            tiles (List[List[torch.Tensor]]): processed tiles

        Returns:
            torch.Tensor: output image (default dtype, on the tiles' device)
        """
        B, output_C, _, _ = self.output_shape
        device = self._tiles_device(tiles)
        output = torch.zeros(self.output_shape, device=device)
        # sum of the gaussian weights applied to each output pixel
        weights = torch.zeros(self.output_shape, device=device)
        # there are only a few distinct tile shapes (full tiles plus truncated
        # border tiles), so build each mask once
        masks = {}
        for id_i, id_j, region in self._output_slots():
            tile = tiles[id_i][id_j]
            tile_h, tile_w = tile.shape[2], tile.shape[3]
            mask = masks.get((tile_h, tile_w))
            if mask is None:
                # cast on CPU first: float64 is not available on every device
                mask = (
                    self._gaussian_weights(tile_w, tile_h, B, output_C)
                    .to(output.dtype)
                    .to(device)
                )
                masks[(tile_h, tile_w)] = mask
            output[region] += tile * mask
            weights[region] += mask
        return output / weights

    def _blend_v(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        """Cross-fade the last rows of ``a`` into the first rows of ``b``, in place.

        Row ``y`` of the blend window becomes
        ``a[-extent + y] * (1 - y / extent) + b[y] * (y / extent)``.
        Returns ``b``.
        """
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        if blend_extent <= 0:
            # nothing to blend (also avoids the a[..., -0:, ...] whole-tensor slice)
            return b
        ramp = torch.arange(blend_extent, dtype=torch.float64) / blend_extent
        w_b = ramp.to(b.dtype).to(b.device)[:, None]
        w_a = (1 - ramp).to(b.dtype).to(b.device)[:, None]
        b[:, :, :blend_extent, :] = (
            a[:, :, -blend_extent:, :] * w_a + b[:, :, :blend_extent, :] * w_b
        )
        return b

    def _blend_h(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        """Cross-fade the last columns of ``a`` into the first columns of ``b``, in place.

        Column ``x`` of the blend window becomes
        ``a[..., -extent + x] * (1 - x / extent) + b[..., x] * (x / extent)``.
        Returns ``b``.
        """
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        if blend_extent <= 0:
            return b
        ramp = torch.arange(blend_extent, dtype=torch.float64) / blend_extent
        w_b = ramp.to(b.dtype).to(b.device)
        w_a = (1 - ramp).to(b.dtype).to(b.device)
        b[:, :, :, :blend_extent] = (
            a[:, :, :, -blend_extent:] * w_a + b[:, :, :, :blend_extent] * w_b
        )
        return b

    def _linear_merge_tiles(self, tiles: List[List[torch.Tensor]]) -> torch.Tensor:
        """Merge tiles by linearly blending the overlapping regions.

        Each tile is blended with its left neighbour and then with its upper
        neighbour. It is then cropped to ``tile - overlap`` along both axes
        before the crops are concatenated.

        Args:
            tiles (List[List[torch.Tensor]]): processed tiles (not modified)

        Returns:
            torch.Tensor: output image (same dtype and device as the tiles)
        """
        output_overlap_size_H, output_overlap_size_W = self.output_overlap_size
        output_tile_size_H, output_tile_size_W = self.output_tile_size
        # drop the right/bottom overlap of every tile
        limit_i = output_tile_size_H - output_overlap_size_H
        limit_j = output_tile_size_W - output_overlap_size_W
        # the blend helpers write into tiles in place, so work on private copies
        blended = [[tile.clone() for tile in row] for row in tiles]
        res_rows = []
        for i, row in enumerate(blended):
            res_row = []
            for j, tile in enumerate(row):
                if j > 0:
                    tile = self._blend_h(row[j - 1], tile, output_overlap_size_W)
                if i > 0:
                    tile = self._blend_v(blended[i - 1][j], tile, output_overlap_size_H)
                row[j] = tile
                res_row.append(tile[:, :, :limit_i, :limit_j])
            res_rows.append(torch.cat(res_row, dim=3))
        return torch.cat(res_rows, dim=2)
def extract_into_tensor(
    a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int, ...]
) -> torch.Tensor:
    """
    Pick per-sample values out of ``a`` and shape them for broadcasting.

    Values are gathered along the last dimension of ``a`` at the indices held in
    ``t``. They are then reshaped to ``(t.shape[0], 1, ..., 1)``, which has
    ``len(x_shape)`` dimensions in total.

    :param a: the tensor to gather values from (along its last dimension).
    :param t: the index tensor; its leading dimension is the batch size.
    :param x_shape: the shape whose rank the result should match.
    :return: the gathered values followed by trailing singleton dimensions.
    """
    batch_size, *_unused = t.shape
    picked = a.gather(-1, t)
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch_size, *singleton_dims)
def pad(x: torch.Tensor, base_h: int, base_w: int) -> torch.Tensor:
    """
    Zero-pad the bottom and right of ``x`` up to the next multiple of the bases.

    The last two dimensions are treated as (height, width). If both are already
    multiples of ``base_h`` / ``base_w``, ``x`` itself is returned unchanged.

    :param x: the tensor to pad.
    :param base_h: the base height.
    :param base_w: the base width.
    :return: the padded tensor.
    """
    height, width = x.shape[-2:]
    pad_right = abs(base_w * math.ceil(width / base_w) - width)
    pad_bottom = abs(base_h * math.ceil(height / base_h) - height)
    if pad_right == 0 and pad_bottom == 0:
        return x
    # F.pad orders the padding as (left, right, top, bottom) for the last two dims
    return F.pad(x, (0, pad_right, 0, pad_bottom))
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Return a view of ``x`` with trailing singleton axes added until it has ``target_dims`` dims.

    Raises ValueError when ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        index = (Ellipsis,) + (None,) * missing
        return x[index]
    raise ValueError(
        f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
    )
@torch.no_grad()
def update_ema(
    target_params: List[torch.Tensor],
    source_params: List[torch.Tensor],
    rate: float = 0.99,
):
    """
    Move each target tensor toward its source with an exponential moving average.

    Each target is updated in place as ``target = rate * target + (1 - rate) * source``.
    Pairs are formed with ``zip``, so any extra items in the longer sequence are ignored.

    :param target_params: the target parameter sequence (modified in place).
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    source_weight = 1 - rate
    for target, source in zip(target_params, source_params):
        target.detach().mul_(rate).add_(source, alpha=source_weight)