mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
71 lines
2.0 KiB
Python
71 lines
2.0 KiB
Python
import logging
|
|
import PIL
|
|
import torch
|
|
from torchvision.transforms import ToPILImage, ToTensor
|
|
from .lbm import LBMModel
|
|
from modules import devices
|
|
|
|
|
|
# NOTE(review): basicConfig configures the *root* logger at import time. Per the
# stdlib docs it does nothing if the root logger already has handlers, so in a
# host application that set up logging first this line has no effect.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
# Supported (width, height) resolutions, keyed by the string form of their
# width / height ratio (e.g. "0.25" -> (512, 2048)). Entries are listed in the
# same order as before; dicts preserve insertion order.
ASPECT_RATIOS = {
    str(width / height): (width, height)
    for width, height in (
        (512, 2048),
        (1024, 1024),
        (2048, 512),
        (896, 1152),
        (1152, 896),
        (512, 1920),
        (640, 1536),
        (768, 1280),
        (1280, 768),
        (1536, 640),
        (1920, 512),
    )
}
|
|
|
|
|
|
@torch.no_grad()
def evaluate(
    model: LBMModel,
    source_image: PIL.Image.Image,
    num_sampling_steps: int = 1,
):
    """
    Evaluate the model on an image coming from the source distribution and generate a new image from the target distribution.

    The source image is resized to the ``ASPECT_RATIOS`` entry whose
    width/height ratio is closest to its own, encoded with ``model.vae``,
    sampled with ``model.sample``, and the result is resized back to the
    source image's original size.

    Args:
        model (LBMModel): The model to evaluate.
        source_image (PIL.Image.Image): The source image to evaluate the model on.
        num_sampling_steps (int): The number of sampling steps to use for the model.

    Returns:
        PIL.Image.Image: The generated image, at the same size as ``source_image``.
    """
    # PIL's Image.size is (width, height).
    orig_width, orig_height = source_image.size
    aspect_ratio = orig_width / orig_height
    # ASPECT_RATIOS keys are str(width / height); choose the nearest supported ratio.
    closest_ratio = min(ASPECT_RATIOS, key=lambda ratio: abs(float(ratio) - aspect_ratio))
    source_image = source_image.resize(ASPECT_RATIOS[closest_ratio])

    # ToTensor yields values in [0, 1] for 8-bit images; rescale to [-1, 1]
    # and add a leading batch dimension.
    source_tensor = ToTensor()(source_image).unsqueeze(0) * 2 - 1
    batch = {
        "source_image": source_tensor.to(dtype=devices.dtype, device=devices.device),
    }

    # NOTE(review): assumes model.source_key == "source_image"; otherwise this
    # lookup raises KeyError — confirm against LBMModel configuration.
    z_source = model.vae.encode(batch[model.source_key])

    output_image = model.sample(
        z=z_source,
        num_steps=num_sampling_steps,
        conditioner_inputs=batch,
        max_samples=1,
    ).clamp(-1, 1)

    # Take the first sample and map it from [-1, 1] back to [0, 1] for ToPILImage.
    output_image = (output_image[0].float().cpu() + 1) / 2
    output_image = ToPILImage()(output_image)
    # Image.resize returns a new image rather than resizing in place, so the
    # result must be kept; previously it was discarded and the output stayed
    # at the model's working resolution.
    output_image = output_image.resize((orig_width, orig_height))

    return output_image
|