1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/face/reswapper.py
Vladimir Mandic f5a910f719 experimental reswapper
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-01-20 16:57:55 -05:00

112 lines
5.1 KiB
Python

from typing import List
import os
import cv2
import torch
import numpy as np
import huggingface_hub as hf
from PIL import Image
from modules import processing, shared, devices
# Hugging Face repository that hosts the ReSwapper checkpoints.
RESWAPPER_REPO = 'somanchiu/reswapper'
# UI display name -> checkpoint filename inside RESWAPPER_REPO.
RESWAPPER_MODELS = {
    "ReSwapper 256 0.2": "reswapper_256-1567500.pth",
    "ReSwapper 256 0.1": "reswapper_256-1399500.pth",
    "ReSwapper 128 0.2": "reswapper-429500.pth",
    "ReSwapper 128 0.1": "reswapper-1019500.pth",
}
# Module-level cache: the loaded network and the display name it was loaded for.
reswapper_model = reswapper_name = None
# Trace logging is enabled only when SD_FACE_DEBUG is present in the environment;
# otherwise `debug` accepts any arguments and does nothing.
if "SD_FACE_DEBUG" in os.environ:
    debug = shared.log.trace
else:
    def debug(*args, **kwargs): # pylint: disable=unused-argument
        return None
# dtype used for the model weights and its input tensors.
# NOTE(review): copied from devices.dtype once at import time; if devices.dtype is changed later this copy goes stale — confirm.
dtype = devices.dtype
def get_model(model_name: str):
    """Return the ReSwapper network for ``model_name``, downloading and loading it on first use.

    The loaded network is cached in the module globals ``reswapper_model`` /
    ``reswapper_name`` and reused for as long as the same ``model_name`` is requested.

    Args:
        model_name: display name; must be a key of ``RESWAPPER_MODELS``.

    Returns:
        The model in eval mode, moved to ``devices.device`` with the module ``dtype``,
        or ``None`` if the name is unknown or downloading/loading fails (the error is logged).
        A failed load never returns a partially initialised model or a model cached for
        a different name.
    """
    global reswapper_model, reswapper_name # pylint: disable=global-statement
    if reswapper_model is not None and reswapper_name == model_name:
        return reswapper_model
    fn = RESWAPPER_MODELS.get(model_name)
    if fn is None:
        shared.log.error(f'ReSwapper: model="{model_name}" unknown model')
        return None
    url = None # bound before the try so the error log below can always reference it
    try:
        url = hf.hf_hub_download(repo_id=RESWAPPER_REPO, filename=fn, repo_type="model", cache_dir=shared.opts.hfcache_dir)
        # drop the previously cached network before loading a new one so two models never coexist in memory
        reswapper_model = None
        reswapper_name = None
        from modules.face.reswapper_model import ReSwapperModel
        model = ReSwapperModel()
        # NOTE(review): torch.load unpickles a file downloaded from the Hub; consider weights_only=True
        # if the checkpoint is confirmed to be a plain state dict.
        state_dict = torch.load(url, map_location='cpu')
        result = model.load_state_dict(state_dict, strict=False)
        # strict=False silently skips mismatched keys; surface them so a wrong checkpoint is noticeable
        debug(f'ReSwapper: model="{model_name}" missing={len(result.missing_keys)} unexpected={len(result.unexpected_keys)}')
        model = model.to(device=devices.device, dtype=dtype)
        model.eval()
    except Exception as e:
        shared.log.error(f'ReSwapper: model="{model_name}" fn="{fn}" url="{url}" {e}')
        return None
    # publish to the cache only after the model is fully loaded
    reswapper_model = model
    reswapper_name = model_name
    shared.log.info(f'ReSwapper: model="{model_name}" url="{url}" cls={reswapper_model.__class__.__name__}')
    return reswapper_model
def _detect_faces(app, image: Image.Image):
    """Convert a PIL image to a BGR array and run ``app.get`` face detection on it.

    Returns ``(bgr_array, faces)``.
    """
    image_np = cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)
    return image_np, app.get(image_np)


def _face_str(face) -> str:
    """Format the detection score, gender and age of a detected face for the debug log."""
    return f'score:{face.det_score:.2f} gender:{"female" if face.gender==0 else "male"} age:{face.age}'


def _tensor_to_bgr(tensor: torch.Tensor) -> np.ndarray:
    """Convert a single-image model output tensor to a uint8 BGR array.

    The original scaling by 255 implies a nominal [0, 1] output range; values are clipped
    to that range first so out-of-range pixels saturate instead of wrapping in the uint8 cast.
    """
    face = tensor.float().squeeze().permute(1, 2, 0).cpu().detach().numpy()
    face = (np.clip(face, 0.0, 1.0) * 255).astype(np.uint8)
    return cv2.cvtColor(face, cv2.COLOR_RGB2BGR)


def reswapper(
        p: processing.StableDiffusionProcessing,
        app,
        source_images: List[Image.Image],
        target_images: List[Image.Image],
        model_name: str,
        original: bool,
    ):
    """Run ReSwapper face swapping for every face detected in ``source_images``.

    For face ``y`` of each source image, the latent comes from ``utils.getLatent`` on that
    source face. The aligned model input comes from ``target_images[y]``, with the last
    target image reused when there are more faces than target images. The swapped crop is
    blended back into the source image and appended as a new output image.

    Args:
        p: processing object; ``p.extra_generation_params['ReSwapper']`` is set to the swap count.
        app: face analyser (presumably insightface ``FaceAnalysis``) whose ``get(bgr_array)``
            returns the detected faces.
        source_images: images whose faces are detected and swapped.
        target_images: images that must each contain exactly one detectable face.
        model_name: key of ``RESWAPPER_MODELS``; a "256" in the name selects 256px, otherwise 128px.
        original: when true, the unmodified ``source_images`` are placed first in the output.

    Returns:
        ``None`` when ``source_images`` is empty. ``source_images`` unchanged when the model
        cannot be loaded, ``target_images`` is empty, a source image has no face, or a target
        image does not contain exactly one face. Otherwise, the list of output images.
    """
    from modules.face import reswapper_utils as utils
    if source_images is None or len(source_images) == 0:
        shared.log.warning('ReSwapper: no input images')
        return None
    if target_images is None or len(target_images) == 0:
        # without this guard the target_images[...] lookup below raises IndexError
        shared.log.warning('ReSwapper: no target images')
        return source_images
    processed_images = list(source_images) if original else []
    model = get_model(model_name)
    if model is None:
        return source_images
    model = model.to(device=devices.device)
    resolution = 256 if '256' in model_name else 128
    # target index -> (bgr array, detected faces); each target image is detected once,
    # not once per source image and face as before
    target_cache = {}
    i = 0
    for x, image in enumerate(source_images):
        source_np, source_faces = _detect_faces(app, image)
        if len(source_faces) == 0:
            shared.log.error(f"ReSwapper: image={x+1} no source faces found")
            return source_images
        if len(source_faces) != len(target_images):
            shared.log.warning(f"ReSwapper: image={x+1} source-faces={len(source_faces)} target-images={len(target_images)}")
        for y, source_face in enumerate(source_faces):
            t = min(y, len(target_images) - 1) # surplus source faces reuse the last target image
            if t not in target_cache:
                target_cache[t] = _detect_faces(app, target_images[t])
            target_np, target_faces = target_cache[t]
            if len(target_faces) != 1:
                shared.log.error(f"ReSwapper: image={x+1} source-faces={y+1} target-faces={len(target_faces)} must be exactly one")
                return source_images
            target_face = target_faces[0]
            shared.log.debug(f'ReSwapper image={x+1} face={y+1} source="{_face_str(source_face)}" target="{_face_str(target_face)}"')
            source_latent = utils.getLatent(source_face)
            source_tensor = torch.from_numpy(source_latent).to(device=devices.device, dtype=dtype)
            target_aligned, M = utils.norm_crop2(target_np, target_face.kps, resolution)
            target_blob = utils.getBlob(target_aligned, (resolution, resolution))
            target_tensor = torch.from_numpy(target_blob).to(device=devices.device, dtype=dtype)
            with devices.inference_context():
                swapped_tensor = model(target_tensor, source_tensor)
            swapped_face = _tensor_to_bgr(swapped_tensor)
            # NOTE(review): M is computed from aligning the face in target_np, but the crop is blended into
            # source_np. Unless both images share geometry, the face lands at the target face's position.
            # Also, each face produces its own output image blended into the unmodified source_np.
            # Confirm the intended source/target roles.
            swapped_np = utils.blend_swapped_image(swapped_face, source_np, M)
            processed_images.append(Image.fromarray(cv2.cvtColor(swapped_np, cv2.COLOR_BGR2RGB)))
            i += 1
    p.extra_generation_params['ReSwapper'] = f'faces={i}'
    devices.torch_gc()
    return processed_images