Mirror of https://github.com/vladmandic/sdnext.git
Synced 2026-01-29 05:02:09 +03:00
170 lines · 7.8 KiB · Python
import functools
|
|
import os
|
|
import cv2
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from einops import rearrange
|
|
from huggingface_hub import hf_hub_download
|
|
from PIL import Image
|
|
from modules import devices
|
|
from modules.shared import opts
|
|
from modules.control.util import HWC3, resize_image
|
|
|
|
|
|
class UnetGenerator(nn.Module):
    """U-Net generator assembled recursively from nested UnetSkipConnectionBlocks."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Assemble the U-Net, starting at the bottleneck and wrapping outwards.

        Parameters:
            input_nc (int)     -- number of channels in the input images
            output_nc (int)    -- number of channels in the output images
            num_downs (int)    -- number of downsampling steps; e.g. with num_downs == 7
                                  a 128x128 image is reduced to 1x1 at the bottleneck
            ngf (int)          -- number of filters in the last conv layer
            norm_layer         -- normalization layer (a class or a functools.partial of one)
            use_dropout (bool) -- enable dropout in the extra ngf * 8 intermediate levels only

        The innermost level, the three channel-reducing levels and the outermost level
        are always built, so values of num_downs below 5 still produce 5 levels.
        """
        super().__init__()
        # Bottleneck level: ngf * 8 filters, no submodule.
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        # Extra levels that keep ngf * 8 filters; these are the only ones that honour use_dropout.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block, norm_layer=norm_layer, use_dropout=use_dropout)
        # Step the filter count down: ngf * 8 -> ngf * 4 -> ngf * 2 -> ngf.
        for outer_mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * outer_mult * 2, input_nc=None, submodule=block, norm_layer=norm_layer)
        # Outermost level maps input_nc channels to output_nc channels and owns the whole tree.
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer)

    def forward(self, input):  # pylint: disable=redefined-builtin
        """Run the complete U-Net on ``input``."""
        return self.model(input)
|
|
|
|
|
|
class UnetSkipConnectionBlock(nn.Module):
    """One level of the U-Net: downsample, run the inner submodule, upsample, add a skip.

        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|

    Every level except the outermost concatenates its input with its output along the
    channel dimension. A non-outermost block therefore returns ``input_nc + outer_nc``
    channels (``2 * outer_nc`` by default), which is why the enclosing level's upconv
    takes ``inner_nc * 2`` input channels.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a U-Net submodule with a skip connection.

        Parameters:
            outer_nc (int)     -- number of filters in the outer conv layer (upconv output channels)
            inner_nc (int)     -- number of filters in the inner conv layer (downconv output channels)
            input_nc (int)     -- number of channels in input images/features; defaults to outer_nc
            submodule (UnetSkipConnectionBlock) -- previously defined inner submodule (ignored when innermost)
            outermost (bool)   -- if this module is the outermost module
            innermost (bool)   -- if this module is the innermost module
            norm_layer         -- normalization layer class, or a functools.partial wrapping one
            use_dropout (bool) -- append nn.Dropout(0.5); only applied to intermediate blocks

        The order of layers inside ``self.model`` determines the state-dict keys
        (``model.<index>...``), so it must stay stable for pretrained checkpoints to load.
        """
        super().__init__()
        self.outermost = outermost
        # Unwrap partials (including subclasses of functools.partial) to find the norm class.
        # A conv feeding BatchNorm does not need its own bias; keep the bias only for InstanceNorm.
        base_norm = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc

        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        uprelu = nn.ReLU(True)

        if outermost:
            # The outermost upconv always keeps its bias and ends in Tanh; there is no norm and no skip concat.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            model = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # There is no submodule, so the upconv sees only inner_nc channels (no concatenated skip).
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            model = [downrelu, downconv, uprelu, upconv, norm_layer(outer_nc)]
        else:
            downnorm = norm_layer(inner_nc)
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            model = [downrelu, downconv, downnorm, submodule, uprelu, upconv, norm_layer(outer_nc)]
            if use_dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Apply the block. Non-outermost blocks concatenate ``x`` with the result on dim 1."""
        if self.outermost:
            return self.model(x)
        # Skip connection: stack the unchanged input alongside the processed features.
        return torch.cat([x, self.model(x)], 1)
|
|
|
|
|
|
class LineartAnimeDetector:
    """Anime line-art annotator wrapping a pix2pix-style generator network."""

    def __init__(self, model):
        # The torch module that performs the image -> line-map translation
        # (a UnetGenerator when built via from_pretrained).
        self.model = model

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False):
        """Build a detector from a generator checkpoint.

        Args:
            pretrained_model_or_path: a local directory containing ``filename``, or a
                Hugging Face Hub repo id passed to ``hf_hub_download``.
            filename: checkpoint file name; defaults to ``"netG.pth"``.
            cache_dir: forwarded to ``hf_hub_download``.
            local_files_only: forwarded to ``hf_hub_download``.

        Returns:
            A detector whose 8-level UnetGenerator (3 -> 1 channels, InstanceNorm)
            is in eval mode with its weights on the CPU.
        """
        filename = filename or "netG.pth"
        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
        net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
        # map_location="cpu" lets checkpoints saved from a GPU load on CPU-only hosts;
        # __call__ moves the model to the compute device on demand.
        # NOTE(security): torch.load unpickles the file - only load trusted checkpoints.
        ckpt = torch.load(model_path, map_location="cpu")
        for key in list(ckpt.keys()):
            # Drop "module." from key names (presumably left by DataParallel-style wrappers)
            # so they match this bare network's state-dict keys.
            if 'module.' in key:
                ckpt[key.replace('module.', '')] = ckpt[key]
                del ckpt[key]
        net.load_state_dict(ckpt)
        net.eval()
        return cls(net)

    def to(self, device):
        """Move the wrapped model to ``device`` and return ``self`` for chaining."""
        self.model.to(device)
        return self

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
        """Produce an inverted line-art map for ``input_image``.

        Args:
            input_image: a numpy array, or anything ``np.array(..., dtype=np.uint8)``
                accepts (presumably a PIL image - confirm against callers).
            detect_resolution: resolution passed to ``resize_image`` before inference.
            image_resolution: resolution passed to ``resize_image`` to pick the output size.
            output_type: ``"pil"`` returns a PIL Image; any other value returns the array.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            The uint8 map ``255 - line`` after HWC3 conversion and resizing.
        """
        self.model.to(devices.device)
        try:
            device = next(iter(self.model.parameters())).device
            if not isinstance(input_image, np.ndarray):
                input_image = np.array(input_image, dtype=np.uint8)
            input_image = HWC3(input_image)
            input_image = resize_image(input_image, detect_resolution)
            H, W, _C = input_image.shape
            # Resize the working image up to multiples of 256: the generator built by
            # from_pretrained halves the resolution 8 times (2 ** 8 == 256).
            Hn = 256 * int(np.ceil(float(H) / 256.0))
            Wn = 256 * int(np.ceil(float(W) / 256.0))
            img = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC)
            # Inference only: without no_grad the output requires grad and .numpy() raises.
            with torch.no_grad():
                image_feed = torch.from_numpy(img).float().to(device)
                image_feed = image_feed / 127.5 - 1.0  # [0, 255] -> [-1, 1]
                image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
                line = self.model(image_feed)[0, 0] * 127.5 + 127.5  # back to [0, 255]
                line = line.cpu().numpy()
        finally:
            # Offload even if inference fails, so the model does not stay on the accelerator.
            if opts.control_move_processor:
                self.model.to('cpu')

        line = cv2.resize(line, (W, H), interpolation=cv2.INTER_CUBIC)
        line = line.clip(0, 255).astype(np.uint8)
        detected_map = HWC3(line)
        img = resize_image(input_image, image_resolution)
        H, W, _C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
        detected_map = 255 - detected_map
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        return detected_map
|