# mirror of https://github.com/vladmandic/sdnext.git (synced 2026-01-27 15:02:48 +03:00)
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD

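# Note (added for clarity): OPENAI_DATASET_MEAN / OPENAI_DATASET_STD are the standard
# CLIP normalization constants, roughly (0.481, 0.458, 0.408) and (0.269, 0.261, 0.276).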

class ResizeMaxSize(nn.Module):
    """Resize so the longest edge equals `max_size`, then pad the result to a `max_size` square."""

    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        self.fn = min if fn == 'min' else max
        self.fill = fill

    def forward(self, img):
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]  # tensors are laid out (..., H, W)
        else:
            width, height = img.size  # PIL images report (W, H)
        scale = self.max_size / float(max(height, width))
        if scale != 1.0:
            new_size = tuple(round(dim * scale) for dim in (height, width))
            img = F.resize(img, new_size, self.interpolation)
            # pad symmetrically so the output is max_size x max_size
            pad_h = self.max_size - new_size[0]
            pad_w = self.max_size - new_size[1]
            img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
        return img

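# Usage sketch (illustrative, not part of the original file): ResizeMaxSize implements
# aspect-preserving "letterbox" preprocessing. Assuming a PIL input, the longer edge is
# scaled to max_size and the shorter edge is padded with `fill`:
#
#     from PIL import Image
#     letterbox = ResizeMaxSize(224, fill=0)
#     out = letterbox(Image.new('RGB', (640, 480)))  # -> 224x224, padded top/bottom
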
def _convert_to_rgb(image):
    # ensure a 3-channel RGB PIL image before ToTensor()/Normalize()
    return image.convert('RGB')

# class CatGen(nn.Module):
#     def __init__(self, num=4):
#         self.num = num
#
#     def mixgen_batch(image, text):
#         batch_size = image.shape[0]
#         index = np.random.permutation(batch_size)
#
#         cat_images = []
#         for i in range(batch_size):
#             # image mixup
#             image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
#             # text concat
#             text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
#         text = torch.stack(text)
#         return image, text

def image_transform(
        image_size: int,
        is_train: bool,
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        resize_longest_max: bool = False,
        fill_color: int = 0,
):
    mean = mean or OPENAI_DATASET_MEAN
    if not isinstance(mean, (list, tuple)):
        mean = (mean,) * 3

    std = std or OPENAI_DATASET_STD
    if not isinstance(std, (list, tuple)):
        std = (std,) * 3

    if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
        image_size = image_size[0]

    normalize = Normalize(mean=mean, std=std)
    if is_train:
        # training: light random-resized-crop augmentation
        return Compose([
            RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])
    else:
        # evaluation: either letterbox to a square or shortest-edge resize + center crop
        if resize_longest_max:
            transforms = [
                ResizeMaxSize(image_size, fill=fill_color)
            ]
        else:
            transforms = [
                Resize(image_size, interpolation=InterpolationMode.BICUBIC),
                CenterCrop(image_size),
            ]
        transforms.extend([
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])
        return Compose(transforms)
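# Usage sketch (illustrative, not part of the original file): image_transform() builds
# CLIP-style preprocessing pipelines that map a PIL image to a normalized
# (3, image_size, image_size) float tensor:
#
#     train_preprocess = image_transform(224, is_train=True)
#     eval_preprocess = image_transform(224, is_train=False, resize_longest_max=True)
#     x = eval_preprocess(pil_image)  # pil_image: any PIL.Image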