mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-26 23:41:35 +03:00
kijai a9e21f164c Squashed commit of the following:
commit 916fc0b1bc
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 15 17:30:37 2025 +0200

    Update nodes.py

commit 63818324f5
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 15 17:30:26 2025 +0200

    Refactor RoPE caching

commit bb0c55da4d
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 15 01:59:16 2025 +0200

    Update nodes_sampler.py

commit a0447d5553
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 15 01:28:09 2025 +0200

    Fix non scale wfs

commit fa761cc2f2
Merge: ea1677b 3aae54f
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 15 01:26:23 2025 +0200

    Merge branch 'main' into SCAIL

commit ea1677bd4a
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sun Dec 14 19:41:43 2025 +0200

    Handle torchscript issue better

    Some other custom nodes globally set torch._C._jit_set_profiling_executor(False), which breaks the NLF model.

commit e3cfa64bd3
Merge: ad7a0b9 3611341
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sun Dec 14 16:49:04 2025 +0200

    Merge branch 'main' into SCAIL

commit ad7a0b925d
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sun Dec 14 16:10:34 2025 +0200

    Fix possible uni3c issue

commit 74d97fa4bb
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sun Dec 14 15:58:42 2025 +0200

    Match Uni3C temporal dim

commit 056d8ad96f
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sun Dec 14 14:47:58 2025 +0200

    Add warning for other potential overrides of torch.jit.script

commit f6dff002ff
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sun Dec 14 14:19:33 2025 +0200

    Add option to warm up the NLF model on load and fix its offloading

commit a19107501d
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sun Dec 14 13:45:20 2025 +0200

    Add error to indicate ComfyUI-RMBG currently breaks the NLF model

commit e2cfa486e4
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 13 23:29:49 2025 +0200

    Cleanup unnecessary code

commit 462b61855f
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 13 18:05:10 2025 +0200

    context windows

commit e57d4baeeb
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 13 16:55:23 2025 +0200

    Start/end percentages and strength

commit 3e507ae322
Merge: 1e5c7cb 0fa5383
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 13 16:09:16 2025 +0200

    Merge branch 'main' into SCAIL

commit 1e5c7cb211
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 13 15:45:39 2025 +0200

    Update nodes.py

commit 98f8e56bca
Merge: 9652146 78e3e18
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 13 15:42:44 2025 +0200

    Merge branch 'main' into SCAIL

commit 9652146763
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 13 02:41:06 2025 +0200

    Add imitation of SCAIL pose drawing to the existing NLF node

    This only draws the pose with the same colors; it's not meant as a final solution, just for testing.

commit 1f86cebdaa
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 13 01:11:56 2025 +0200

    test pose inputs

commit b348b21dbe
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Fri Dec 12 20:10:48 2025 +0200

    Init

import json
import os

import torch
import torchvision.transforms.functional as TF

import node_helpers
from comfy import model_management as mm

from ..utils import log
from .trajectory import create_pos_feature_map, draw_tracks_on_video, replace_feature

device = mm.get_torch_device()
script_directory = os.path.dirname(os.path.abspath(__file__))

VAE_STRIDE = (4, 8, 8)  # latent strides in (t, h, w)
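
# A worked example of the stride math (values are illustrative, not from the repo):
# with VAE_STRIDE = (4, 8, 8), a latent grid of (lat_t, lat_h, lat_w) = (21, 60, 104)
# corresponds to a pixel-space video of ((21 - 1) * 4 + 1, 60 * 8, 104 * 8)
# = (81, 480, 832) in (frames, height, width).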

class WanVideoWanDrawWanMoveTracks:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "images": ("IMAGE",),
                "tracks": ("TRACKS",),
            },
            "optional": {
                "line_resolution": ("INT", {"default": 24, "min": 4, "max": 64, "step": 1, "tooltip": "Number of points to use for each line segment"}),
                "circle_size": ("INT", {"default": 10, "min": 1, "max": 20, "step": 1, "tooltip": "Size of the circle to draw for each track point"}),
                "opacity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Opacity of the circle to draw for each track point"}),
                "line_width": ("INT", {"default": 14, "min": 1, "max": 50, "step": 1, "tooltip": "Width of the line to draw for each track"}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    FUNCTION = "execute"
    CATEGORY = "WanVideoWrapper"

    def execute(self, images, tracks, line_resolution=24, circle_size=10, opacity=0.5, line_width=14):
        if tracks is None or "track_path" not in tracks:
            log.warning("WanVideoWanDrawWanMoveTracks: No tracks provided.")
            return (images.float().cpu(),)

        # Add the batch dimension expected by draw_tracks_on_video
        track = tracks["track_path"].unsqueeze(0)
        track_visibility = tracks["track_visibility"].unsqueeze(0)

        # Scale to 0-255 and tile the images if fewer frames than track points were given
        images_in = images * 255.0
        if images_in.shape[0] != track.shape[1]:
            repeat_count = track.shape[1] // images.shape[0]
            images_in = images_in.repeat(repeat_count, 1, 1, 1)

        track_video = draw_tracks_on_video(images_in, track, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
        # Back to a (frames, H, W, C) float tensor in [0, 1]
        track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1)

        return (track_video.float().cpu(),)
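
# Shape sketch for the node above (an assumption read off the calls, not documented
# in the repo): ComfyUI IMAGE tensors are (N, H, W, C) in [0, 1] and
# tracks["track_path"] is (frames, num_tracks, 2) in pixel coordinates, so the
# unsqueeze(0) calls add the batch dim and a single input frame is tiled out to the
# track length before drawing.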

class WanVideoAddWanMoveTracks:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image_embeds": ("WANVIDIMAGE_EMBEDS",),
                "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "tooltip": "Strength of the reference embedding"}),
            },
            "optional": {
                "track_mask": ("MASK",),
                "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}),
                "tracks": ("TRACKS", {"tooltip": "Alternatively use Comfy Tracks dictionary"}),
            }
        }

    RETURN_TYPES = ("WANVIDIMAGE_EMBEDS", "TRACKS")
    RETURN_NAMES = ("image_embeds", "tracks")
    FUNCTION = "add"
    CATEGORY = "WanVideoWrapper"

    def add(self, image_embeds, track_coords=None, tracks=None, strength=1.0, track_mask=None):
        updated = dict(image_embeds)
        track_visibility = None

        # Recover pixel dimensions from the latent shape
        target_shape = image_embeds.get("target_shape")
        if target_shape is not None:
            height = target_shape[2] * VAE_STRIDE[1]
            width = target_shape[3] * VAE_STRIDE[2]
        else:
            height = image_embeds["lat_h"] * VAE_STRIDE[1]
            width = image_embeds["lat_w"] * VAE_STRIDE[2]
        num_frames = image_embeds["num_frames"]

        if track_coords is not None:
            tracks_data = parse_json_tracks(track_coords)
            track_list = [
                [[track[frame]['x'], track[frame]['y']] for track in tracks_data]
                for frame in range(len(tracks_data[0]))
            ]
            track = torch.tensor(track_list, dtype=torch.float32, device=device)  # shape: (frames, num_tracks, 2)
        elif tracks is not None and "track_path" in tracks:
            track = tracks["track_path"]
            if track_mask is None:
                track_visibility = tracks.get("track_visibility", None)
        else:
            # Both track inputs are optional; without this guard the code below
            # would fail with a NameError on an undefined 'track'
            raise ValueError("WanVideoAddWanMoveTracks: either track_coords or tracks must be provided")

        track = track[:num_frames]
        num_tracks = track.shape[-2]

        if track_visibility is None:
            if track_mask is None:
                track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device)
            else:
                track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)

        feature_map, track_pos = create_pos_feature_map(track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device)

        updated.setdefault("wanmove_embeds", {})
        updated["wanmove_embeds"]["track_pos"] = track_pos
        updated["wanmove_embeds"]["strength"] = strength

        tracks_dict = {
            "track_path": track,
            "track_visibility": track_visibility,
        }
        return (updated, tracks_dict,)
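
# Sketch of the mask-to-visibility reduction above (shapes are an assumption based
# on the usual ComfyUI MASK convention of (frames, H, W)):
#   (track_mask > 0).any(dim=(1, 2))  ->  (frames,)    one bool per frame
#   .unsqueeze(-1)                    ->  (frames, 1)  broadcasts over the track dim
# i.e. a frame counts as visible when any mask pixel in it is non-zero.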

def parse_json_tracks(tracks):
    tracks_data = []
    try:
        if isinstance(tracks, str):
            # A single string: parse it as JSON (single quotes are tolerated)
            parsed = json.loads(tracks.replace("'", '"'))
            tracks_data.extend(parsed)
        else:
            # A list of strings: parse each one as its own track
            for track_str in tracks:
                parsed = json.loads(track_str.replace("'", '"'))
                tracks_data.append(parsed)

        # Check if we have a single track (dict with x,y) or a list of tracks
        if tracks_data and isinstance(tracks_data[0], dict) and 'x' in tracks_data[0]:
            # Single track detected, wrap it in a list
            tracks_data = [tracks_data]
        elif tracks_data and isinstance(tracks_data[0], list) and tracks_data[0] and isinstance(tracks_data[0][0], dict) and 'x' in tracks_data[0][0]:
            # Already a list of tracks, nothing to do
            pass
        else:
            # Unexpected or empty format; guard the index so an empty parse can't raise
            log.warning(f"Unexpected track format: {type(tracks_data[0]) if tracks_data else 'empty input'}")
    except json.JSONDecodeError as e:
        log.warning(f"Error parsing tracks JSON: {e}")
        tracks_data = []
    return tracks_data
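
# Minimal usage sketch (hypothetical inputs, not from the repo); both forms parse
# to a list of tracks, each track a per-frame list of {'x', 'y'} dicts:
#
#   single = "[{'x': 10, 'y': 20}, {'x': 12, 'y': 24}]"         # one track as one string
#   multi = ["[{'x': 10, 'y': 20}]", "[{'x': 50, 'y': 60}]"]    # one string per track
#   assert parse_json_tracks(single) == [[{'x': 10, 'y': 20}, {'x': 12, 'y': 24}]]
#   assert parse_json_tracks(multi) == [[{'x': 10, 'y': 20}], [{'x': 50, 'y': 60}]]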

class WanMove_native:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "positive": ("CONDITIONING",),
                "track_coords": ("STRING", {"forceInput": True, "tooltip": "JSON string or list of JSON strings representing the tracks"}),
            },
            "optional": {
                "track_mask": ("MASK",),
            }
        }

    RETURN_TYPES = ("CONDITIONING", "TRACKS")
    RETURN_NAMES = ("positive", "tracks")
    FUNCTION = "patchcond"
    CATEGORY = "WanVideoWrapper"
    DEPRECATED = True

    def patchcond(self, positive, track_coords, track_mask=None):
        concat_latent_image = positive[0][1]["concat_latent_image"]
        B, C, T, H, W = concat_latent_image.shape
        # Latent -> pixel dimensions (temporal stride 4, spatial stride 8)
        num_frames = (T - 1) * 4 + 1
        width = W * 8
        height = H * 8

        tracks_data = parse_json_tracks(track_coords)
        if not tracks_data:
            # Fail early with a clear message instead of an IndexError below
            raise ValueError("WanMove_native: could not parse any tracks from track_coords")
        track_list = [
            [[track[frame]['x'], track[frame]['y']] for track in tracks_data]
            for frame in range(len(tracks_data[0]))
        ]
        track = torch.tensor(track_list, dtype=torch.float32, device=device)  # shape: (frames, num_tracks, 2)
        track = track[:num_frames]
        num_tracks = track.shape[-2]

        if track_mask is None:
            track_visibility = torch.ones((num_frames, num_tracks), dtype=torch.bool, device=device)
        else:
            track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)

        feature_map, track_pos = create_pos_feature_map(track, track_visibility, VAE_STRIDE, height, width, 16, track_num=num_tracks, device=device)
        wanmove_cond = replace_feature(concat_latent_image, track_pos.unsqueeze(0))
        positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": wanmove_cond})

        tracks_dict = {
            "track_path": track,
            "track_visibility": track_visibility,
        }
        return (positive, tracks_dict)
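
# Note on the patch above: node_helpers.conditioning_set_values writes the new
# concat_latent_image onto each conditioning entry, so the track features travel
# with the image conditioning into the native sampler (an assumption from how
# that helper is generally used, not stated in this file).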

NODE_CLASS_MAPPINGS = {
    "WanVideoAddWanMoveTracks": WanVideoAddWanMoveTracks,
    "WanVideoWanDrawWanMoveTracks": WanVideoWanDrawWanMoveTracks,
    "WanMove_native": WanMove_native,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "WanVideoAddWanMoveTracks": "WanVideo Add WanMove Tracks",
    "WanVideoWanDrawWanMoveTracks": "WanVideo Draw WanMove Tracks",
    "WanMove_native": "WanMove Native",
}
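
# Typical wiring (a sketch, assuming a standard WanVideoWrapper graph): route the
# WANVIDIMAGE_EMBEDS through "WanVideo Add WanMove Tracks" before the sampler, and
# optionally preview the motion with "WanVideo Draw WanMove Tracks" on the input frames.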