ComfyUI-WanVideoWrapper (mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git)

Merge branch 'main' into bindweave
example_workflows/wanvideo_1_3B_UniLumos_relight_example_01.json (new file, 1938 lines; diff suppressed because it is too large)
nodes.py (42 lines changed)
@@ -1246,6 +1246,46 @@ class WanVideoAnimateEmbeds:
        }

        return (image_embeds,)

# region UniLumos
class WanVideoUniLumosEmbeds:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "width": ("INT", {"default": 832, "min": 64, "max": 8096, "step": 8, "tooltip": "Width of the image to encode"}),
            "height": ("INT", {"default": 480, "min": 64, "max": 8096, "step": 8, "tooltip": "Height of the image to encode"}),
            "num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
            },
            "optional": {
                "foreground_latents": ("LATENT", {"tooltip": "Video foreground latents"}),
                "background_latents": ("LATENT", {"tooltip": "Video background latents"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", )
    RETURN_NAMES = ("image_embeds",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"

    def process(self, num_frames, width, height, foreground_latents=None, background_latents=None):
        target_shape = (16, (num_frames - 1) // VAE_STRIDE[0] + 1,
                        height // VAE_STRIDE[1],
                        width // VAE_STRIDE[2])

        embeds = {
            "target_shape": target_shape,
            "num_frames": num_frames,
        }
        if foreground_latents is not None:
            embeds["foreground_latents"] = foreground_latents["samples"][0]
        else:
            embeds["foreground_latents"] = torch.zeros(target_shape[0], target_shape[1], target_shape[2], target_shape[3], device=torch.device("cpu"), dtype=torch.float32)
        if background_latents is not None:
            embeds["background_latents"] = background_latents["samples"][0]
        else:
            embeds["background_latents"] = torch.zeros(target_shape[0], target_shape[1], target_shape[2], target_shape[3], device=torch.device("cpu"), dtype=torch.float32)

        return (embeds,)

class WanVideoEmptyEmbeds:
    @classmethod
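As a quick sanity check of the target_shape math above (my own arithmetic using the node defaults and this module's VAE_STRIDE = (4, 8, 8); not part of the commit):

    VAE_STRIDE = (4, 8, 8)
    num_frames, width, height = 81, 832, 480                # node defaults
    target_shape = (16,
                    (num_frames - 1) // VAE_STRIDE[0] + 1,  # 81 frames -> 21 latent frames
                    height // VAE_STRIDE[1],                # 480 -> 60
                    width // VAE_STRIDE[2])                 # 832 -> 104
    assert target_shape == (16, 21, 60, 104)

Note that the node fills both foreground_latents and background_latents with zero tensors of target_shape when the optional inputs are absent, so consumers of the embeds can rely on both keys being present.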
@@ -2296,6 +2336,7 @@ NODE_CLASS_MAPPINGS = {
    "WanVideoSchedulerSA_ODE": WanVideoSchedulerSA_ODE,
    "WanVideoAddBindweaveEmbeds": WanVideoAddBindweaveEmbeds,
    "TextImageEncodeQwenVL": TextImageEncodeQwenVL,
    "WanVideoUniLumosEmbeds": WanVideoUniLumosEmbeds,
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -2336,4 +2377,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddLucyEditLatents": "WanVideo Add LucyEdit Latents",
    "WanVideoSchedulerSA_ODE": "WanVideo Scheduler SA-ODE",
    "WanVideoAddBindweaveEmbeds": "WanVideo Add Bindweave Embeds",
    "WanVideoUniLumosEmbeds": "WanVideo UniLumos Embeds",
}
@@ -1088,6 +1088,14 @@ class WanVideoModelLoader:
            sd, reader = load_gguf(model_path)
            gguf_reader.append(reader)

        # Ovi
        extra_audio_model = False
        if any(key.startswith("video_model.") for key in sd.keys()):
            sd = {key.replace("video_model.", "", 1).replace("modulation.modulation", "modulation"): value for key, value in sd.items()}
            if any(key.startswith("audio_model.") for key in sd.keys()) and any(key.startswith("blocks.") for key in sd.keys()):
                extra_audio_model = True

        is_wananimate = "pose_patch_embedding.weight" in sd
        # rename WanAnimate face fuser block keys to insert into main blocks instead
        if is_wananimate:
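For illustration, the Ovi key remap above applied to a toy state dict (hypothetical keys, not from the commit):

    sd = {
        "video_model.blocks.0.attn.weight": 0,
        "video_model.modulation.modulation.bias": 1,
        "audio_model.blocks.0.attn.weight": 2,
    }
    sd = {key.replace("video_model.", "", 1).replace("modulation.modulation", "modulation"): v for key, v in sd.items()}
    # -> {"blocks.0.attn.weight": 0, "modulation.bias": 1, "audio_model.blocks.0.attn.weight": 2}
    # "audio_model." keys now coexist with bare "blocks." keys, which is what flips extra_audio_model to True.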
@@ -1140,7 +1148,6 @@ class WanVideoModelLoader:
            raise ValueError("You are attempting to load a VACE module as a WanVideo model, instead you should use the vace_model input and matching T2V base model")

        # currently this can be VACE, MTV-Crafter, Lynx or Ovi-audio weights
-       extra_audio_model = False
        if extra_model is not None:
            for _model in extra_model:
                print("Loading extra model: ", _model["path"])
@@ -1141,6 +1141,16 @@ class WanVideoSampler:
            lynx_embeds["ref_buffer_uncond"] = lynx_ref_buffer_uncond if not math.isclose(cfg[0], 1.0) else None
            mm.soft_empty_cache()

        # UniLumos
        foreground_latents = image_embeds.get("foreground_latents", None)
        if foreground_latents is not None:
            log.info(f"UniLumos foreground latent input shape: {foreground_latents.shape}")
            foreground_latents = foreground_latents.to(device, dtype)
        background_latents = image_embeds.get("background_latents", None)
        if background_latents is not None:
            log.info(f"UniLumos background latent input shape: {background_latents.shape}")
            background_latents = background_latents.to(device, dtype)

        #region model pred
        def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
                             control_latents=None, vace_data=None, unianim_data=None, audio_proj=None, control_camera_latents=None,
@@ -1361,6 +1371,9 @@ class WanVideoSampler:
            else:
                self.noise_front_pad_num = 0

            if background_latents is not None or foreground_latents is not None:
                z = torch.cat([z, foreground_latents.to(z), background_latents.to(z)], dim=0)

            base_params = {
                'x': [z], # latent
                'y': [image_cond_input] if image_cond_input is not None else None, # image cond
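Because WanVideoUniLumosEmbeds always populates both latents (zeros when an input is missing), the guard above sees either both tensors or neither. A hypothetical shape check of the concat, assuming z follows the (C=16, T, H/8, W/8) layout of target_shape:

    import torch
    z = torch.zeros(16, 21, 60, 104)                  # noise latent
    fg = torch.zeros(16, 21, 60, 104)                 # foreground latents
    bg = torch.zeros(16, 21, 60, 104)                 # background latents
    z = torch.cat([z, fg.to(z), bg.to(z)], dim=0)     # stacked along the channel axis
    assert z.shape == (48, 21, 60, 104)               # 16 noise + 16 fg + 16 bg channels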
@@ -1,6 +1,8 @@
import torch
import numpy as np
from comfy.utils import common_upscale
from comfy import model_management
from tqdm import tqdm
from .utils import log
from einops import rearrange

@@ -12,6 +14,9 @@ except:
VAE_STRIDE = (4, 8, 8)
PATCH_SIZE = (1, 2, 2)

main_device = model_management.get_torch_device()
offload_device = model_management.unet_offload_device()

class WanVideoImageResizeToClosest:
    @classmethod
    def INPUT_TYPES(s):
@@ -660,6 +665,96 @@ class FaceMaskFromPoseKeypoints:
                cv2.fillPoly(canvas, pts=[outer_contour], color=part_color)

        return canvas


class DrawGaussianNoiseOnImage:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image": ("IMAGE", ),
            "mask": ("MASK", ),
            },
            "optional": {
                "device": (["cpu", "gpu"], {"default": "cpu", "tooltip": "Device to use for processing"}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            }
        }

    RETURN_TYPES = ("IMAGE", )
    RETURN_NAMES = ("images",)
    FUNCTION = "apply"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = "Fills the background (unmasked area) with Gaussian noise sampled using the mean and variance of the subject (masked) region."

    def apply(self, image, mask, device="cpu", seed=0):
        B, H, W, C = image.shape
        BM, HM, WM = mask.shape

        processing_device = main_device if device == "gpu" else torch.device("cpu")

        in_masks = mask.clone().to(processing_device)
        in_images = image.clone().to(processing_device)

        # Resize mask to match image dimensions
        if HM != H or WM != W:
            in_masks = F.interpolate(mask.unsqueeze(1), size=(H, W), mode='nearest-exact').squeeze(1)

        # Match batch sizes
        if B > BM:
            in_masks = in_masks.repeat((B + BM - 1) // BM, 1, 1)[:B]
        elif BM > B:
            in_masks = in_masks[:B]

        output_images = []

        # Set random seed for reproducibility
        generator = torch.Generator(device=processing_device).manual_seed(seed)

        for i in tqdm(range(B), desc="DrawGaussianNoiseOnImage batch"):
            curr_mask = in_masks[i]
            img_idx = min(i, B - 1)
            curr_image = in_images[img_idx]

            # Expand mask to 3 channels
            mask_expanded = curr_mask.unsqueeze(-1).expand(-1, -1, 3)

            # Calculate mean and std per channel from the subject region (where mask is 1)
            subject_mask = mask_expanded > 0.5

            # Initialize noise tensor
            noise = torch.zeros_like(curr_image)

            for c in range(C):
                channel = curr_image[:, :, c]
                channel_mask = subject_mask[:, :, c]

                if channel_mask.sum() > 0:
                    # Get subject pixels
                    subject_pixels = channel[channel_mask]

                    # Calculate statistics
                    mean = subject_pixels.mean()
                    std = subject_pixels.std()

                    # Generate Gaussian noise for this channel
                    noise[:, :, c] = torch.normal(mean=mean.item(), std=std.item(),
                                                  size=(H, W), generator=generator,
                                                  device=processing_device)

            # Clamp noise to valid range
            noise = torch.clamp(noise, 0.0, 1.0)

            # Apply: keep subject, fill background with noise
            masked_image = curr_image * mask_expanded + noise * (1 - mask_expanded)
            output_images.append(masked_image)

        # If no masks were processed, return empty tensor
        if not output_images:
            return (torch.zeros((0, H, W, 3), dtype=image.dtype),)

        out_rgb = torch.stack(output_images, dim=0).cpu()

        return (out_rgb, )

NODE_CLASS_MAPPINGS = {
    "WanVideoImageResizeToClosest": WanVideoImageResizeToClosest,
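The node boils down to per-channel Gaussian statistics taken from the kept region; a minimal single-image sketch of the same idea (a standalone, hypothetical helper, not part of the commit):

    import torch

    def noise_fill(image: torch.Tensor, mask: torch.Tensor, seed: int = 0) -> torch.Tensor:
        # image: (H, W, 3) in [0, 1]; mask: (H, W), 1 = subject pixels to keep
        gen = torch.Generator(device=image.device).manual_seed(seed)
        mask3 = mask.unsqueeze(-1).expand(-1, -1, 3)
        noise = torch.zeros_like(image)
        for c in range(3):
            subject = image[:, :, c][mask3[:, :, c] > 0.5]
            if subject.numel() > 0:
                # noise matches the subject's per-channel mean/std
                noise[:, :, c] = torch.normal(subject.mean().item(), subject.std().item(),
                                              size=image.shape[:2], generator=gen, device=image.device)
        # keep the subject, fill the background with clamped noise
        return image * mask3 + noise.clamp(0.0, 1.0) * (1 - mask3)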
@@ -673,6 +768,7 @@ NODE_CLASS_MAPPINGS = {
    "NormalizeAudioLoudness": NormalizeAudioLoudness,
    "WanVideoPassImagesFromSamples": WanVideoPassImagesFromSamples,
    "FaceMaskFromPoseKeypoints": FaceMaskFromPoseKeypoints,
    "DrawGaussianNoiseOnImage": DrawGaussianNoiseOnImage,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoImageResizeToClosest": "WanVideo Image Resize To Closest",
@@ -686,4 +782,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
    "NormalizeAudioLoudness": "Normalize Audio Loudness",
    "WanVideoPassImagesFromSamples": "WanVideo Pass Images From Samples",
    "FaceMaskFromPoseKeypoints": "Face Mask From Pose Keypoints",
    "DrawGaussianNoiseOnImage": "Draw Gaussian Noise On Image",
}
@@ -1,7 +1,7 @@
[project]
name = "ComfyUI-WanVideoWrapper"
description = "ComfyUI wrapper nodes for WanVideo"
-version = "1.3.8"
+version = "1.3.9"
license = {file = "LICENSE"}
dependencies = ["accelerate >= 1.2.1", "diffusers >= 0.33.0", "peft >= 0.17.0", "ftfy", "gguf >= 0.17.1", "pyloudnorm"]