import json
import os

import torch
import torchvision.transforms.functional as TF

import node_helpers
from comfy import model_management as mm

from ..utils import log
from .trajectory import create_pos_feature_map, draw_tracks_on_video, replace_feature

device = mm.get_torch_device()
script_directory = os.path.dirname(os.path.abspath(__file__))

VAE_STRIDE = (4, 8, 8)  # Wan VAE compression factors along (t, h, w)
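# Example (assuming Wan's usual 4x temporal / 8x spatial compression): an 81-frame
# 832x480 video encodes to a latent of shape (21, 60, 104), since
# (81 - 1) / 4 + 1 = 21, 480 / 8 = 60 and 832 / 8 = 104.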

class WanVideoWanDrawWanMoveTracks:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "images": ("IMAGE",),
            "tracks": ("TRACKS",),
            },
            "optional": {
                "line_resolution": ("INT", {"default": 24, "min": 4, "max": 64, "step": 1, "tooltip": "Number of points to use for each line segment"}),
                "circle_size": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1, "tooltip": "Size of the circle to draw for each track point"}),
                "opacity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Opacity of the circle to draw for each track point"}),
                "line_width": ("INT", {"default": 14, "min": 1, "max": 50, "step": 1, "tooltip": "Width of the line to draw for each track"}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "execute"
    CATEGORY = "WanVideoWrapper"

    def execute(self, images, tracks, line_resolution=24, circle_size=10, opacity=0.5, line_width=14):
        if tracks is None or "track_path" not in tracks:
            log.warning("WanVideoWanDrawWanMoveTracks: No tracks provided.")
            return (images.float().cpu(), )

        # Add a batch dimension: (frames, num_tracks, 2) -> (1, frames, num_tracks, 2)
        track = tracks["track_path"].unsqueeze(0)
        track_visibility = tracks["track_visibility"].unsqueeze(0)

        # The drawing helper works on 0-255 pixel values
        images_in = images * 255.0
        if images_in.shape[0] != track.shape[1]:
            # Repeat a shorter image batch (e.g. a single still) to cover every track frame
            repeat_count = track.shape[1] // images_in.shape[0]
            images_in = images_in.repeat(repeat_count, 1, 1, 1)

        track_video = draw_tracks_on_video(images_in, track, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
        # Convert the drawn frames back to a (frames, H, W, C) float tensor in 0..1
        track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1)

        return (track_video.float().cpu(), )
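
# TRACKS dict layout, as inferred from its producers and consumers in this file
# (not a formal spec): "track_path" holds float x/y pixel coordinates of shape
# (frames, num_tracks, 2); "track_visibility" holds a bool flag per point, shaped
# (frames, num_tracks) or a broadcastable (frames, 1).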
class WanVideoAddWanMoveTracks:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image_embeds": ("WANVIDIMAGE_EMBEDS",),
            "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Strength of the reference embedding"}),
            },
            "optional": {
                "track_mask": ("MASK",),
                "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}),
                "tracks": ("TRACKS", {"tooltip": "Alternatively use Comfy Tracks dictionary"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "TRACKS")
    RETURN_NAMES = ("image_embeds", "tracks")
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, image_embeds, track_coords=None, tracks=None, strength=1.0, track_mask=None):
        updated = dict(image_embeds)

        track_visibility = None

        # Recover the pixel resolution from the latent dimensions
        target_shape = image_embeds.get("target_shape")
        if target_shape is not None:
            height = target_shape[2] * VAE_STRIDE[1]
            width = target_shape[3] * VAE_STRIDE[2]
        else:
            height = image_embeds["lat_h"] * VAE_STRIDE[1]
            width = image_embeds["lat_w"] * VAE_STRIDE[2]
        num_frames = image_embeds["num_frames"]

        if track_coords is not None:
            tracks_data = parse_json_tracks(track_coords)
            if not tracks_data:
                raise ValueError("WanVideoAddWanMoveTracks: could not parse any tracks from track_coords.")
            # Transpose from per-track point lists to per-frame coordinate lists
            track_list = [
                [[track[frame]['x'], track[frame]['y']] for track in tracks_data]
                for frame in range(len(tracks_data[0]))
            ]
            track = torch.tensor(track_list, dtype=torch.float32, device=device) # shape: (frames, num_tracks, 2)
        elif tracks is not None and "track_path" in tracks:
            track = tracks["track_path"]
            if track_mask is None:
                track_visibility = tracks.get("track_visibility", None)
        else:
            raise ValueError("WanVideoAddWanMoveTracks: either track_coords or tracks must be provided.")
        track = track[:num_frames]

        num_tracks = track.shape[-2]
        if track_visibility is None:
            if track_mask is None:
                # No mask and no stored visibility: treat every point as visible
                track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device)
            else:
                # A frame counts as visible when its mask has any nonzero pixel
                track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)

        feature_map, track_pos = create_pos_feature_map(track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device)

        updated.setdefault("wanmove_embeds", {})
        updated["wanmove_embeds"]["track_pos"] = track_pos
        updated["wanmove_embeds"]["strength"] = strength

        tracks_dict = {
            "track_path": track,
            "track_visibility": track_visibility,
        }

        return (updated, tracks_dict,)
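
# Accepted track_coords formats (coordinate values below are illustrative only):
#   single track:    '[{"x": 10, "y": 20}, {"x": 12, "y": 22}]'
#   list of tracks:  '[[{"x": 10, "y": 20}], [{"x": 50, "y": 60}]]'
# Single quotes are tolerated and rewritten to double quotes before JSON parsing.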
def parse_json_tracks(tracks):
    tracks_data = []
    try:
        # A single string is parsed directly as JSON
        if isinstance(tracks, str):
            parsed = json.loads(tracks.replace("'", '"'))
            tracks_data.extend(parsed)
        else:
            # A list of strings is parsed entry by entry
            for track_str in tracks:
                parsed = json.loads(track_str.replace("'", '"'))
                tracks_data.append(parsed)

        # Normalize: a single track (a list of {x, y} dicts) is wrapped into a list of tracks
        if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
            tracks_data = [tracks_data]
        elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
            # Already a list of tracks, nothing to do
            pass
        else:
            log.warning(f"Unexpected track format: {type(tracks_data[0]) if tracks_data else 'empty input'}")

    except json.JSONDecodeError as e:
        log.warning(f"Error parsing tracks JSON: {e}")
        tracks_data = []

    return tracks_data
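
# A quick sketch of the normalization (values are made up):
#   parse_json_tracks('[{"x": 1, "y": 2}]')   -> [[{'x': 1, 'y': 2}]]
#   parse_json_tracks('[[{"x": 1, "y": 2}]]') -> [[{'x': 1, 'y': 2}]]
#   parse_json_tracks('not json')             -> [] (with a logged warning)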

class WanMove_native:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "positive": ("CONDITIONING",),
            "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}),
            },
            "optional": {
                "track_mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("CONDITIONING", "TRACKS")
    RETURN_NAMES = ("positive", "tracks")
    FUNCTION = "patchcond"
    CATEGORY = "WanVideoWrapper"
    DEPRECATED = True

    def patchcond(self, positive, track_coords, track_mask=None):
        # Recover pixel dimensions and frame count from the concat latent:
        # e.g. a (B, C, 21, 60, 104) latent corresponds to 81 frames at 832x480.
        concat_latent_image = positive[0][1]["concat_latent_image"]
        B, C, T, H, W = concat_latent_image.shape
        num_frames = (T - 1) * 4 + 1
        width = W * 8
        height = H * 8

        tracks_data = parse_json_tracks(track_coords)
        # Transpose from per-track point lists to per-frame coordinate lists
        track_list = [
            [[track[frame]['x'], track[frame]['y']] for track in tracks_data]
            for frame in range(len(tracks_data[0]))
        ]
        track = torch.tensor(track_list, dtype=torch.float32, device=device) # shape: (frames, num_tracks, 2)
        track = track[:num_frames]

        num_tracks = track.shape[-2]
        if track_mask is None:
            track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device)
        else:
            # A frame counts as visible when its mask has any nonzero pixel
            track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)

        feature_map, track_pos = create_pos_feature_map(track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device)
        # Bake the track features into the conditioning's concat latent
        wanmove_cond = replace_feature(concat_latent_image, track_pos.unsqueeze(0))
        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": wanmove_cond})

        tracks_dict = {
            "track_path": track,
            "track_visibility": track_visibility,
        }
        return (positive, tracks_dict)
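
# Note: WanMove_native patches ComfyUI's native conditioning (concat_latent_image)
# directly, while WanVideoAddWanMoveTracks stores the track features in the wrapper's
# own image_embeds dict; the TRACKS output of either node can be previewed with
# WanVideoWanDrawWanMoveTracks.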

NODE_CLASS_MAPPINGS = {
    "WanVideoAddWanMoveTracks": WanVideoAddWanMoveTracks,
    "WanVideoWanDrawWanMoveTracks": WanVideoWanDrawWanMoveTracks,
    "WanMove_native": WanMove_native,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddWanMoveTracks": "WanVideo Add WanMove Tracks",
    "WanVideoWanDrawWanMoveTracks": "WanVideo Draw WanMove Tracks",
    "WanMove_native": "WanMove Native",
}