import os
import types
import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from modules import devices
from modules.shared import opts
from modules.control.util import HWC3, resize_image
from .nets.NNET import NNET


# load model
def load_checkpoint(fpath, model):
    ckpt = torch.load(fpath, map_location='cpu')['model']

    load_dict = {}
    for k, v in ckpt.items():
        # strip the 'module.' prefix added by DataParallel-wrapped checkpoints
        if k.startswith('module.'):
            k_ = k.replace('module.', '')
            load_dict[k_] = v
        else:
            load_dict[k] = v
    model.load_state_dict(load_dict)
    return model


class NormalBaeDetector:
    def __init__(self, model):
        self.model = model
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet mean/std

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None, local_files_only=False):
        filename = filename or "scannet.pt"
        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir, local_files_only=local_files_only)
        # NNET configuration for the ScanNet-pretrained checkpoint
        args = types.SimpleNamespace()
        args.mode = 'client'
        args.architecture = 'BN'
        args.pretrained = 'scannet'
        args.sampling_ratio = 0.4
        args.importance_ratio = 0.7
        model = NNET(args)
        model = load_checkpoint(model_path, model)
        model.eval()
        return cls(model)

    def to(self, device):
        self.model.to(device)
        return self

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
        self.model.to(devices.device)
        device = next(iter(self.model.parameters())).device
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        assert input_image.ndim == 3
        # scale to [0, 1] and normalize to the range expected by the network
        image_normal = input_image
        image_normal = torch.from_numpy(image_normal).float().to(device)
        image_normal = image_normal / 255.0
        image_normal = rearrange(image_normal, 'h w c -> 1 c h w')
        image_normal = self.norm(image_normal)

        normal = self.model(image_normal)
        normal = normal[0][-1][:, :3]
        # d = torch.sum(normal ** 2.0, dim=1, keepdim=True) ** 0.5
        # d = torch.maximum(d, torch.ones_like(d) * 1e-5)
        # normal /= d
        normal = ((normal + 1) * 0.5).clip(0, 1) # map predicted normals from [-1, 1] to [0, 1]
        normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy()
        normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = normal_image
        detected_map = HWC3(detected_map)
        img = resize_image(input_image, image_resolution)
        H, W, _C = img.shape
        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)
        if opts.control_move_processor:
            self.model.to('cpu') # optionally offload the processor back to CPU
        return detected_map
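

# Usage sketch (illustrative only, not executed): the hub repository id below is an
# assumption for documentation purposes and may differ from what the caller passes in.
#   detector = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
#   normal_map = detector(pil_image, detect_resolution=512, image_resolution=512, output_type="pil")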