mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-26 23:41:35 +03:00

Add imitation of SCAIL pose drawing to the existing NLF node

This only draws the pose with the same colors; it's not meant as a final solution, just for testing.
Author: kijai
Date: 2025-12-13 02:41:06 +02:00
parent 1f86cebdaa
commit 9652146763
3 changed files with 3529 additions and 35 deletions

View File

@@ -31,7 +31,7 @@ def p3d_to_p2d(point_3d, height, width): # point3d n*1024*3
def get_pose_images(smpl_data, offset):
    pose_images = []
    for data in smpl_data:
        if isinstance(data, np.ndarray):
            joints3d = data
        else:
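The hunk header above references p3d_to_p2d, which maps per-frame 3D joints of shape n*1024*3 to 2D pixel coordinates before drawing. Its implementation isn't part of this diff; a minimal sketch of such a projection, assuming a simple per-frame normalize-and-scale mapping (the real function may well use camera intrinsics instead), might look like:

import numpy as np

def p3d_to_p2d_sketch(point_3d, height, width):
    # point_3d: (n, 1024, 3) per-frame joints; keep x/y, drop depth.
    xy = point_3d[..., :2]
    mins = xy.min(axis=1, keepdims=True)
    maxs = xy.max(axis=1, keepdims=True)
    norm = (xy - mins) / np.maximum(maxs - mins, 1e-6)  # per-frame [0, 1]
    return norm * np.array([width - 1, height - 1], dtype=np.float32)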
@@ -43,28 +43,33 @@ def get_pose_images(smpl_data, offset):
    return pose_images

-def get_control_conditions(poses, h, w):
-    video_transforms = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+def get_control_conditions(poses, h, w, stick_width=1.0, point_radius=2, style="original"):
    control_images = []
    for idx, pose in enumerate(poses):
        canvas = np.zeros(shape=(h, w, 3), dtype=np.uint8)
        try:
            joints3d = p3d_to_p2d(pose, h, w)
-            canvas = draw_3d_points(
-                canvas,
-                joints3d[0],
-                stickwidth=int(h / 350),
-            )
+            if style == "original":
+                canvas = draw_3d_points(
+                    canvas,
+                    joints3d[0],
+                    stickwidth=int(h / 350 * stick_width),
+                    r=point_radius,
+                )
+            elif style == "scail":
+                canvas = draw_3d_points_scail(
+                    canvas,
+                    joints3d[0],
+                    stickwidth=int(h / 350 * stick_width),
+                    r=point_radius,
+                )
            resized_canvas = cv2.resize(canvas, (w, h))
            # Image.fromarray(resized_canvas).save(f'tmp/{idx}_pose.jpg')
            control_images.append(resized_canvas)
-        except Exception as e:
-            print("wrong:", e)
+        except Exception:
            control_images.append(Image.fromarray(canvas))
    control_pixel_values = np.array(control_images)
    control_pixel_values = torch.from_numpy(control_pixel_values).contiguous() / 255.
    print("control_pixel_values.shape", control_pixel_values.shape)
    #control_pixel_values = video_transforms(control_pixel_values)
    return control_pixel_values
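With the new keyword arguments in place, a caller selects the SCAIL-style renderer per batch of frames. A usage sketch (the poses variable stands in for the joints3d_nonparam output of the NLF predictor; the values are examples only):

# Hypothetical call: returns a float tensor in [0, 1] of shape (num_frames, h, w, 3)
frames = get_control_conditions(poses, 512, 512,
                                stick_width=4.0, point_radius=5, style="scail")

Note that stick_width multiplies the resolution-dependent base width int(h / 350), so the drawn line thickness stays proportional when the output resolution changes.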
@@ -140,3 +145,69 @@ def draw_3d_points(canvas, points, stickwidth=2, r=2, draw_line=True):
            cv2.fillConvexPoly(canvas, polygon, connection_colors[i%17])
    return canvas

+def draw_3d_points_scail(canvas, points, stickwidth=2, r=2, draw_line=True):
+    connections = [
+        [15,12],[12,16],[16,18],[18,20],[20,22],  # 0-4: head-neck and left arm chain
+        [12,17],[17,19],[19,21],                  # 5-7: right arm chain
+        [21,23],                                  # 8: right hand
+        [12,1],[1,4],[4,7],                       # 9-11: neck to left leg (hip, thigh, shin)
+        [12,2],[2,5],[5,8],                       # 12-14: neck to right leg (hip, thigh, shin)
+    ]
+    # Warm colors for the right side, cool colors for the left side
+    connection_colors = [
+        [180, 180, 180],  # 0:  [15,12] - head to neck (Grey)
+        [0, 200, 255],    # 1:  [12,16] - L. shoulder (Bright Cyan)
+        [0, 120, 255],    # 2:  [16,18] - L. upper arm (Bright Blue)
+        [0, 60, 255],     # 3:  [18,20] - L. forearm (Deep Blue)
+        [60, 0, 255],     # 4:  [20,22] - L. hand (Blue-Purple)
+        [255, 0, 0],      # 5:  [12,17] - R. shoulder (Bright Red)
+        [255, 100, 0],    # 6:  [17,19] - R. upper arm (Bright Orange)
+        [255, 180, 0],    # 7:  [19,21] - R. forearm (Golden Orange)
+        [255, 255, 0],    # 8:  [21,23] - R. hand (Bright Yellow)
+        [30, 27, 160],    # 9:  [12,1]  - Neck to L. hip (Purple-Blue)
+        [73, 27, 177],    # 10: [1,4]   - L. thigh (Purple)
+        [145, 27, 194],   # 11: [4,7]   - L. shin (Magenta)
+        [200, 255, 100],  # 12: [12,2]  - Neck to R. hip (Yellow-Green)
+        [54, 201, 52],    # 13: [2,5]   - R. thigh (Green)
+        [30, 176, 85],    # 14: [5,8]   - R. shin (Green)
+    ]
+    # draw limb lines
+    if draw_line:
+        for i in range(len(connections)):
+            point1_idx, point2_idx = connections[i][0:2]
+            point1 = points[point1_idx]
+            point2 = points[point2_idx]
+            x1, y1 = int(point1[0]), int(point1[1])
+            x2, y2 = int(point2[0]), int(point2[1])
+            cv2.line(canvas, (x1, y1), (x2, y2), connection_colors[i], stickwidth)
+    # draw points for joints that have connections
+    joints_in_use = set()
+    for connection in connections:
+        joints_in_use.add(connection[0])
+        joints_in_use.add(connection[1])
+    for joint_idx in joints_in_use:
+        if joint_idx >= len(points):
+            continue
+        x, y = points[joint_idx][0:2]
+        x, y = int(x), int(y)
+        # Use the color from the first connection involving this joint
+        joint_color = [180, 180, 180]  # default grey
+        for i, connection in enumerate(connections):
+            if connection[0] == joint_idx or connection[1] == joint_idx:
+                joint_color = connection_colors[i]
+                break
+        cv2.circle(canvas, (x, y), r, joint_color, thickness=-1)
+    return canvas
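A quick way to eyeball the SCAIL palette is to run the new function on synthetic joints. A smoke-test sketch (the 24 random joint positions are arbitrary, just enough to exercise every connection; the color conversion assumes the palette is RGB):

import numpy as np
import cv2

rng = np.random.default_rng(0)
points = rng.uniform(50, 462, size=(24, 3))  # only x/y are used when drawing
canvas = np.zeros((512, 512, 3), dtype=np.uint8)
canvas = draw_3d_points_scail(canvas, points, stickwidth=3, r=5)
cv2.imwrite("scail_pose_test.png", cv2.cvtColor(canvas, cv2.COLOR_RGB2BGR))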

View File

@@ -1,10 +1,7 @@
import os
import torch
import gc
from ..utils import log, dict_to_device
import numpy as np
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
import comfy.model_management as mm
from comfy.utils import load_torch_file
@@ -17,7 +14,6 @@ offload_device = mm.unet_offload_device()
local_model_path = os.path.join(folder_paths.models_dir, "nlf", "nlf_l_multi_0.3.2.torchscript")
from .motion4d import SMPL_VQVAE, VectorQuantizer, Encoder, Decoder
from .mtv import prepare_motion_embeddings
class DownloadAndLoadNLFModel:
    @classmethod
@@ -38,7 +34,7 @@ class DownloadAndLoadNLFModel:
    CATEGORY = "WanVideoWrapper"

    def loadmodel(self, url):
        if not os.path.exists(local_model_path):
            log.info(f"Downloading NLF model to: {local_model_path}")
            import requests
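The body of the download branch is cut off by the hunk boundary; a typical streamed download with requests into local_model_path (a sketch, not necessarily the wrapper's exact code) would continue:

            os.makedirs(os.path.dirname(local_model_path), exist_ok=True)
            with requests.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(local_model_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
                        f.write(chunk)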
@@ -108,7 +104,7 @@ class LoadVQVAE:
            frame_upsample_rate=[2.0, 2.0],
            joint_upsample_rate=[1.0, 1.0]
        )

        vqvae = SMPL_VQVAE(motion_encoder, motion_decoder, motion_quant).to(device)
        vqvae.load_state_dict(vae_sd, strict=True)
@@ -131,15 +127,6 @@ class MTVCrafterEncodePoses:
    def encode(self, vqvae, poses):
-        # import pickle
-        # with open(os.path.join(script_directory, "data", "sampled_data.pkl"), 'rb') as f:
-        #     data_list = pickle.load(f)
-        # if not isinstance(data_list, list):
-        #     data_list = [data_list]
-        # print(data_list)
-        # smpl_poses = data_list[1]['pose']
        global_mean = np.load(os.path.join(script_directory, "data", "mean.npy")) #global_mean.shape: (24, 3)
        global_std = np.load(os.path.join(script_directory, "data", "std.npy"))
@@ -153,7 +140,7 @@ class MTVCrafterEncodePoses:
        vqvae.to(device)
        motion_tokens, vq_loss = vqvae(norm_poses.to(device), return_vq=True)
        recon_motion = vqvae(norm_poses.to(device))[0][0].to(dtype=torch.float32).cpu().detach() * global_std + global_mean
        vqvae.to(offload_device)
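norm_poses is produced outside this hunk, but since the reconstruction is de-normalized with * global_std + global_mean, the forward step is presumably the inverse z-score using the (24, 3) per-joint statistics loaded above. A sketch of the assumed round trip (the pose shape and batch dimension are assumptions):

import numpy as np
import torch

poses = np.zeros((16, 24, 3), dtype=np.float32)  # (frames, joints, xyz); shape assumed
norm_poses = torch.from_numpy((poses - global_mean) / global_std).unsqueeze(0)  # batch dim assumed
# Inverse, as in the line above: norm_poses * global_std + global_mean recovers the input scale.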
@@ -162,7 +149,7 @@ class MTVCrafterEncodePoses:
            'global_mean': global_mean,
            'global_std': global_std
        }

        return poses_dict, recon_motion
@@ -181,10 +168,13 @@ class NLFPredict:
    CATEGORY = "WanVideoWrapper"

    def predict(self, model, images):
+        prev_fuser_state = torch._C._jit_texpr_fuser_enabled()
        model.to(device)
        torch._C._jit_set_texpr_fuser_enabled(False) # removes warmup delay, may want to enable later
        pred = model.detect_smpl_batched(images.permute(0, 3, 1, 2).to(device))
        model.to(offload_device)
+        torch._C._jit_set_texpr_fuser_enabled(prev_fuser_state)
+        pred = dict_to_device(pred, offload_device)
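Because _jit_set_texpr_fuser_enabled flips a process-global flag, the save/restore pair above would leave the fuser disabled if detect_smpl_batched raised. A context-manager sketch of the same pattern (using the same private torch._C APIs the diff already relies on) keeps the restore on the exception path too:

from contextlib import contextmanager
import torch

@contextmanager
def texpr_fuser_disabled():
    prev = torch._C._jit_texpr_fuser_enabled()  # save current fuser state
    torch._C._jit_set_texpr_fuser_enabled(False)
    try:
        yield
    finally:
        torch._C._jit_set_texpr_fuser_enabled(prev)  # restore even on error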
@@ -197,7 +187,7 @@ class NLFPredict:
                pose_results[key].append(pred[key])
            else:
                pose_results[key].append(None)

        return (pose_results,)

class DrawNLFPoses:
@@ -208,21 +198,27 @@ class DrawNLFPoses:
"width": ("INT", {"default": 512}),
"height": ("INT", {"default": 512}),
},
}
"optional": {
"stick_width": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 1000.0, "step": 0.01, "tooltip": "Stick width multiplier"}),
"point_radius": ("INT", {"default": 5, "min": 1, "max": 10, "step": 1, "tooltip": "Point radius for drawing the pose"}),
"style": (["original", "scail"], {"default": "original", "tooltip": "style of the pose drawing"}),
}
}
RETURN_TYPES = ("IMAGE", )
RETURN_NAMES = ("image",)
FUNCTION = "predict"
CATEGORY = "WanVideoWrapper"
def predict(self, poses, width, height):
def predict(self, poses, width, height, stick_width=1.0, point_radius=2, style="original"):
from .draw_pose import get_control_conditions
print(type(poses))
if isinstance(poses, dict):
pose_input = poses['joints3d_nonparam'][0] if 'joints3d_nonparam' in poses else poses
else:
pose_input = poses
control_conditions = get_control_conditions(pose_input, height, width)
control_conditions = get_control_conditions(pose_input, height, width, stick_width=stick_width, point_radius=point_radius, style=style)
return (control_conditions,)
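Note that the UI default for stick_width is 4.0 while the Python-level default in predict is 1.0, so the two only coincide when the node runs from a graph with its widget defaults. Called directly, bypassing the ComfyUI graph (a hypothetical invocation, with pose_results as returned by NLFPredict above):

node = DrawNLFPoses()
(images,) = node.predict(pose_results, width=512, height=512,
                         stick_width=4.0, point_radius=5, style="scail")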

File diff suppressed because one or more lines are too long