# Allows using full ref with lite ip adapter, or full ref alone without loading the ip weights
import os

import numpy as np
import torch

import comfy.model_management as mm
import folder_paths
from comfy.utils import load_torch_file

from ..utils import log
from .resampler import Resampler

script_directory = os.path.dirname(os.path.abspath(__file__))
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()

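# Typical node flow (as wired in a workflow; nothing here enforces it):
#   LynxInsightFaceCrop -> ip_image -> LynxEncodeFaceIP (with LoadLynxResampler)
#   -> lynx_face_embeds -> WanVideoAddLynxEmbeds -> image_embeds for sampling
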

class LoadLynxResampler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model_name": (folder_paths.get_filename_list("diffusion_models"), {"tooltip": "These models are loaded from 'ComfyUI/models/diffusion_models'"}),
                "precision": (["fp32", "bf16", "fp16"], {"default": "fp16"}),
            },
        }

    RETURN_TYPES = ("LYNXRESAMPLER",)
    RETURN_NAMES = ("resampler",)
    FUNCTION = "loadmodel"
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, model_name, precision):
        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]

        model_path = folder_paths.get_full_path("diffusion_models", model_name)
        resampler_sd = load_torch_file(model_path, safe_load=True)

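        # Infer the output width from the checkpoint itself, so differently
        # sized resampler checkpoints load without extra configuration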
        output_dim = resampler_sd["proj_out.weight"].shape[0]

        resampler = Resampler(
            depth=4,
            dim=1280,
            dim_head=64,
            embedding_dim=512,  # matches the 512-dim ArcFace embedding fed in by LynxEncodeFaceIP
            ff_mult=4,
            heads=20,
            num_queries=16,
            output_dim=output_dim,
            dtype=dtype,
        ).eval()
        resampler.to(offload_device, dtype)
        resampler.load_state_dict(resampler_sd, strict=True)

        return (resampler,)

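
# Detects face landmarks with InsightFace and returns two aligned crops:
# ip_image for the face encoder and ref_image for the reference branch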
class LynxInsightFaceCrop:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "IMAGE",)
    RETURN_NAMES = ("ip_image", "ref_image")
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, image, image_size=112):
        from insightface.utils import face_align

        from .face.face_encoder import get_landmarks_from_image
        from .face.face_utils import align_face

        image_np = (image[0].cpu().numpy() * 255).astype(np.uint8)
        landmarks = get_landmarks_from_image(image_np)

        in_image = np.array(image_np)
        landmark = np.array(landmarks)

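        # Two aligned crops from the same landmarks: a tight ArcFace-style crop
        # for the IP embedding and a wider 256px crop used as the reference image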
        ip_face_aligned = face_align.norm_crop(in_image, landmark=landmark, image_size=image_size)
        ref_face_aligned = align_face(in_image, landmark, extend_face_crop=True, face_size=256)

        ip_face_aligned = torch.from_numpy(ip_face_aligned).unsqueeze(0).float() / 255.0
        ref_face_aligned = torch.from_numpy(ref_face_aligned).unsqueeze(0).float() / 255.0

        # Stretch each crop to the full [0, 1] range
        ip_face_aligned = (ip_face_aligned - ip_face_aligned.min()) / (ip_face_aligned.max() - ip_face_aligned.min())
        ref_face_aligned = (ref_face_aligned - ref_face_aligned.min()) / (ref_face_aligned.max() - ref_face_aligned.min())
        ref_face_aligned = ref_face_aligned[:, :, :, [2, 1, 0]]  # BGR to RGB

        return (ip_face_aligned, ref_face_aligned)

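
# Embeds an aligned face crop with ArcFace and resamples it into IP tokens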
class LynxEncodeFaceIP:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "resampler": ("LYNXRESAMPLER", {"tooltip": "lynx resampler model"}),
                "ip_image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("LYNXIP",)
    RETURN_NAMES = ("lynx_face_embeds",)
    FUNCTION = "encode"
    CATEGORY = "WanVideoWrapper"

    def encode(self, resampler, ip_image):
        from .face.face_encoder import FaceEncoderArcFace

        image_in = ip_image.permute(0, 3, 1, 2).to(device) * 2 - 1  # to [-1, 1]

        # Face embedding via ArcFace
        face_encoder = FaceEncoderArcFace()
        face_encoder.init_encoder_model(device)
        arcface_embed = face_encoder(image_in).to(device, resampler.dtype)[0]

        arcface_embed = arcface_embed.reshape([1, -1, 512])

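        # Run the resampler on the main device, then offload it again; the
        # zeroed embedding provides the unconditional branch for CFG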
        resampler.to(device)
        ip_x = resampler(arcface_embed)
        ip_x_uncond = resampler(arcface_embed * 0)
        resampler.to(offload_device)

        ip_x = ip_x.to(resampler.dtype)
        ip_x_uncond = ip_x_uncond.to(resampler.dtype)

        out_dict = {
            "ip_x": ip_x,
            "ip_x_uncond": ip_x_uncond,
        }

        return (out_dict,)


class DrawArcFaceLandmarks:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "lynx_face_embeds": ("LYNXIP", {"tooltip": "lynx face embeddings"}),
                "image": ("IMAGE", {"tooltip": "Input images for the model"}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("landmarked_image",)
    FUNCTION = "draw"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "Draw face landmarks on an image for visualization/debugging"

    def draw(self, lynx_face_embeds, image):
        import cv2

        landmarks = lynx_face_embeds.get("landmarks")
        if landmarks is None:
            raise ValueError("lynx_face_embeds has no 'landmarks' entry to draw")

        # cv2 draws in place, so work on a contiguous uint8 copy of the frame
        image_np = np.ascontiguousarray((image[0].cpu().numpy() * 255).astype(np.uint8))

        for (x, y) in landmarks:
            cv2.circle(image_np, (int(x), int(y)), radius=3, color=(0, 255, 0), thickness=-1)

        image_out = torch.from_numpy(image_np.astype(np.float32) / 255.0).unsqueeze(0)

        return (image_out,)

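
# Attaches Lynx face/reference conditioning to WanVideo image embeds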
class WanVideoAddLynxEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "embeds": ("WANVIDIMAGE_EMBEDS",),
                "ip_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the ip adapter face feature"}),
                "ref_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "tooltip": "Strength of the reference feature"}),
                "lynx_cfg_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "If above 1.0 while the main cfg_scale is also above 1.0, an extra pass is run; suggested value 2.0"}),
                "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percent to apply the ref"}),
                "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percent to apply the ref"}),
            },
            "optional": {
                "vae": ("WANVAE", {"tooltip": "VAE model, only needed if ref_image is provided"}),
                "lynx_ip_embeds": ("LYNXIP", {"tooltip": "lynx face embeddings"}),
                "ref_image": ("IMAGE",),
                "ref_text_embed": ("WANVIDEOTEXTEMBEDS",),
                "ref_blocks_to_use": ("STRING", {"default": "", "forceInput": True, "tooltip": "Comma-separated list of block indices and ranges to use for the reference feature, e.g. '0-20, 25, 28, 35-39'. If empty, all blocks are used."}),
            },
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

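    # Stores the Lynx conditioning under the "lynx_embeds" key of the image
    # embeds dict; the sampler is expected to pick it up from there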
    def add(self, embeds, ip_scale, ref_scale, start_percent, end_percent, lynx_cfg_scale, vae=None, lynx_ip_embeds=None, ref_image=None, ref_text_embed=None, ref_blocks_to_use=""):
        ref_latent = ref_latent_uncond = None
        if ref_image is not None:
            if ref_text_embed is None:
                raise ValueError("If ref_image is provided, ref_text_embed must also be provided.")
            if vae is None:
                raise ValueError("If ref_image is provided, vae must also be provided.")
            # Encode the reference image, plus a zero image for the uncond branch
            vae.to(device)
            ref_image_in = (ref_image[..., :3].permute(3, 0, 1, 2) * 2 - 1).to(device, vae.dtype)
            ref_latent = vae.encode([ref_image_in], device, tiled=False, sample=True)
            ref_latent_uncond = vae.encode([torch.zeros_like(ref_image_in)], device, tiled=False, sample=True)
            vae.to(offload_device)

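        # Example: "0-20, 25, 28" expands to [0, 1, ..., 20, 25, 28]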
        if ref_blocks_to_use.strip() == "":
            ref_blocks_to_use = None
        else:
            # Parse comma-separated block indices and ranges
            blocks = []
            for item in ref_blocks_to_use.split(","):
                item = item.strip()
                if "-" in item and not item.startswith("-"):
                    # Handle a range like "0-20" or "35-39"
                    try:
                        start, end = item.split("-", 1)
                        start, end = int(start.strip()), int(end.strip())
                        blocks.extend(range(start, end + 1))
                    except ValueError:
                        log.warning(f"Invalid range format: {item}")
                elif item.isdigit():
                    # Handle a single index
                    blocks.append(int(item))
                else:
                    log.warning(f"Invalid block specification: {item}")
            ref_blocks_to_use = sorted(set(blocks))  # remove duplicates and sort
            log.info(f"Using ref blocks: {ref_blocks_to_use}")

        new_entry = {
            "ip_x": lynx_ip_embeds["ip_x"] if lynx_ip_embeds is not None else None,
            "ip_x_uncond": lynx_ip_embeds["ip_x_uncond"] if lynx_ip_embeds is not None else None,
            "ref_latent": ref_latent,
            "ref_latent_uncond": ref_latent_uncond,
            "ref_text_embed": ref_text_embed,
            "ip_scale": ip_scale,
            "ref_scale": ref_scale,
            "cfg_scale": lynx_cfg_scale,
            "start_percent": start_percent,
            "end_percent": end_percent,
            "ref_blocks_to_use": ref_blocks_to_use,
        }

        # Shallow-copy so the incoming embeds dict is not mutated in place
        updated = dict(embeds)
        updated["lynx_embeds"] = new_entry
        return (updated,)


NODE_CLASS_MAPPINGS = {
    "LoadLynxResampler": LoadLynxResampler,
    "LynxEncodeFaceIP": LynxEncodeFaceIP,
    "DrawArcFaceLandmarks": DrawArcFaceLandmarks,
    "WanVideoAddLynxEmbeds": WanVideoAddLynxEmbeds,
    "LynxInsightFaceCrop": LynxInsightFaceCrop,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadLynxResampler": "Load Lynx Resampler",
    "LynxEncodeFaceIP": "Lynx Encode Face IP",
    "DrawArcFaceLandmarks": "Draw ArcFace Landmarks",
    "WanVideoAddLynxEmbeds": "WanVideo Add Lynx Embeds",
    "LynxInsightFaceCrop": "Lynx InsightFace Crop",
}