Mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git (synced 2026-01-26 23:41:35 +03:00)
Add imitation of SCAIL pose drawing to the existing NLF node
This only draws the pose with the same colors; it is not meant as a final solution, just for testing.
MTV/draw_pose.py
@@ -31,7 +31,7 @@ def p3d_to_p2d(point_3d, height, width): # point3d n*1024*3
 def get_pose_images(smpl_data, offset):
     pose_images = []
     for data in smpl_data:
         if isinstance(data, np.ndarray):
             joints3d = data
         else:
@@ -43,28 +43,33 @@ def get_pose_images(smpl_data, offset):
     return pose_images


-def get_control_conditions(poses, h, w):
-    video_transforms = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+def get_control_conditions(poses, h, w, stick_width=1.0, point_radius=2, style="original"):
     control_images = []
     for idx, pose in enumerate(poses):
         canvas = np.zeros(shape=(h, w, 3), dtype=np.uint8)
         try:
             joints3d = p3d_to_p2d(pose, h, w)
-            canvas = draw_3d_points(
-                canvas,
-                joints3d[0],
-                stickwidth=int(h / 350),
-            )
+            if style == "original":
+                canvas = draw_3d_points(
+                    canvas,
+                    joints3d[0],
+                    stickwidth=int(h / 350 * stick_width),
+                    r=point_radius,
+                )
+            elif style == "scail":
+                canvas = draw_3d_points_scail(
+                    canvas,
+                    joints3d[0],
+                    stickwidth=int(h / 350 * stick_width),
+                    r=point_radius,
+                )
             resized_canvas = cv2.resize(canvas, (w, h))
             # Image.fromarray(resized_canvas).save(f'tmp/{idx}_pose.jpg')
             control_images.append(resized_canvas)
-        except Exception as e:
-            print("wrong:", e)
+        except Exception:
             control_images.append(Image.fromarray(canvas))
     control_pixel_values = np.array(control_images)
     control_pixel_values = torch.from_numpy(control_pixel_values).contiguous() / 255.
-    print("control_pixel_values.shape", control_pixel_values.shape)
-    #control_pixel_values = video_transforms(control_pixel_values)
     return control_pixel_values

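For reference, a minimal sketch of the new call site. This is a sketch under assumptions: the import path and the (1, 24, 3) pose shape are illustrative only; real input comes from the NLF predictor upstream.

import numpy as np
from MTV.draw_pose import get_control_conditions  # adjust path for your install

# One frame of fake 3D joints; real data is produced by NLFPredict.
# The (1, 24, 3) shape is an assumption for illustration, not normative.
poses = [np.zeros((1, 24, 3), dtype=np.float32)]

# Default call keeps the old look:
frames = get_control_conditions(poses, 512, 512)

# SCAIL-style test rendering with thicker sticks and larger joint dots:
frames = get_control_conditions(poses, 512, 512, stick_width=4.0, point_radius=5, style="scail")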
@@ -140,3 +145,69 @@ def draw_3d_points(canvas, points, stickwidth=2, r=2, draw_line=True):
             cv2.fillConvexPoly(canvas, polygon, connection_colors[i%17])

     return canvas

+
+def draw_3d_points_scail(canvas, points, stickwidth=2, r=2, draw_line=True):
+
+    connetions = [
+        [15,12],[12,16],[16,18],[18,20],[20,22],  # 0-4: Left arm chain
+        [12,17],[17,19],[19,21],                  # 5-7: Right arm chain
+        [21,23],                                  # 8: Right hand
+        [12,1],[1,4],[4,7],                       # 9-11: Neck to left leg (hip, thigh, shin)
+        [12,2],[2,5],[5,8],                       # 12-14: Neck to right leg (hip, thigh, shin)
+    ]
+
+    # Warm colors for right side, cool colors for left side
+    connection_colors = [
+        [180, 180, 180],  # 0: [15,12] - L. clavicle (Grey)
+        [0, 200, 255],    # 1: [12,16] - L. shoulder (Bright Cyan)
+        [0, 120, 255],    # 2: [16,18] - L. upper arm (Bright Blue)
+        [0, 60, 255],     # 3: [18,20] - L. forearm (Deep Blue)
+        [60, 0, 255],     # 4: [20,22] - L. hand (Blue-Purple)
+        [255, 0, 0],      # 5: [12,17] - R. clavicle (Bright Red)
+        [255, 100, 0],    # 6: [17,19] - R. upper arm (Bright Orange)
+        [255, 180, 0],    # 7: [19,21] - R. forearm (Golden Orange)
+        [255, 255, 0],    # 8: [21,23] - R. hand (Bright Yellow)
+        [30, 27, 160],    # 9: [12,1] - Neck to L. hip (purple-blue)
+        [73, 27, 177],    # 10: [1,4] - L. thigh (purple)
+        [145, 27, 194],   # 11: [4,7] - L. shin (magenta)
+        [200, 255, 100],  # 12: [12,2] - Neck to R. hip (yellow-green)
+        [54, 201, 52],    # 13: [2,5] - R. thigh (green)
+        [30, 176, 85],    # 14: [5,8] - R. shin (green)
+    ]
+
+    # draw lines
+    if draw_line:
+        for i in range(len(connetions)):
+            point1_idx, point2_idx = connetions[i][0:2]
+            point1 = points[point1_idx]
+            point2 = points[point2_idx]
+            x1, y1 = int(point1[0]), int(point1[1])
+            x2, y2 = int(point2[0]), int(point2[1])
+            cv2.line(canvas, (x1, y1), (x2, y2), connection_colors[i], stickwidth)
+
+    # draw points for joints that have connections
+    joints_in_use = set()
+    for connection in connetions:
+        joints_in_use.add(connection[0])
+        joints_in_use.add(connection[1])
+
+    for joint_idx in joints_in_use:
+        if joint_idx >= len(points):
+            continue
+        x, y = points[joint_idx][0:2]
+        x, y = int(x), int(y)
+        # Use the color from the first connection involving this joint
+        joint_color = [180, 180, 180]  # default grey
+        for i, connection in enumerate(connetions):
+            if connection[0] == joint_idx or connection[1] == joint_idx:
+                joint_color = connection_colors[i]
+                break
+        cv2.circle(canvas, (x, y), r, joint_color, thickness=-1)
+
+    return canvas

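To eyeball the new palette without running a full workflow, a minimal smoke test; the import path and the random joint layout are assumptions for illustration:

import numpy as np
import cv2
from MTV.draw_pose import draw_3d_points_scail  # adjust path for your install

# 24 fake SMPL joint positions in pixel space; only the indices named in the
# connection list (1, 2, 4, 5, 7, 8, 12, 15-23) are actually drawn.
rng = np.random.default_rng(0)
points = rng.uniform(64, 448, size=(24, 2))

canvas = np.zeros((512, 512, 3), dtype=np.uint8)
canvas = draw_3d_points_scail(canvas, points, stickwidth=3, r=5)
cv2.imwrite("scail_pose_test.png", canvas)  # imwrite treats the channels as BGR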
MTV/nodes.py (42 lines changed)
@@ -1,10 +1,7 @@
 import os
 import torch
-import gc
 from ..utils import log, dict_to_device
 import numpy as np
-from accelerate import init_empty_weights
-from accelerate.utils import set_module_tensor_to_device

 import comfy.model_management as mm
 from comfy.utils import load_torch_file
@@ -17,7 +14,6 @@ offload_device = mm.unet_offload_device()
 local_model_path = os.path.join(folder_paths.models_dir, "nlf", "nlf_l_multi_0.3.2.torchscript")

 from .motion4d import SMPL_VQVAE, VectorQuantizer, Encoder, Decoder
 from .mtv import prepare_motion_embeddings

 class DownloadAndLoadNLFModel:
     @classmethod
@@ -38,7 +34,7 @@ class DownloadAndLoadNLFModel:
     CATEGORY = "WanVideoWrapper"

     def loadmodel(self, url):

         if not os.path.exists(local_model_path):
             log.info(f"Downloading NLF model to: {local_model_path}")
             import requests
@@ -108,7 +104,7 @@ class LoadVQVAE:
             frame_upsample_rate=[2.0, 2.0],
             joint_upsample_rate=[1.0, 1.0]
         )

         vqvae = SMPL_VQVAE(motion_encoder, motion_decoder, motion_quant).to(device)
         vqvae.load_state_dict(vae_sd, strict=True)

@@ -131,15 +127,6 @@ class MTVCrafterEncodePoses:

     def encode(self, vqvae, poses):

-        # import pickle
-        # with open(os.path.join(script_directory, "data", "sampled_data.pkl"), 'rb') as f:
-        #     data_list = pickle.load(f)
-        # if not isinstance(data_list, list):
-        #     data_list = [data_list]
-        # print(data_list)
-
-        # smpl_poses = data_list[1]['pose']
-
         global_mean = np.load(os.path.join(script_directory, "data", "mean.npy")) #global_mean.shape: (24, 3)
         global_std = np.load(os.path.join(script_directory, "data", "std.npy"))

@@ -153,7 +140,7 @@ class MTVCrafterEncodePoses:

         vqvae.to(device)
         motion_tokens, vq_loss = vqvae(norm_poses.to(device), return_vq=True)

         recon_motion = vqvae(norm_poses.to(device))[0][0].to(dtype=torch.float32).cpu().detach() * global_std + global_mean
         vqvae.to(offload_device)

@@ -162,7 +149,7 @@ class MTVCrafterEncodePoses:
             'global_mean': global_mean,
             'global_std': global_std
         }

         return poses_dict, recon_motion

@@ -181,10 +168,13 @@ class NLFPredict:
     CATEGORY = "WanVideoWrapper"

     def predict(self, model, images):

+        prev_fuser_state = torch._C._jit_texpr_fuser_enabled()
         model.to(device)
+        torch._C._jit_set_texpr_fuser_enabled(False) # removes warmup delay, may want to enable later
         pred = model.detect_smpl_batched(images.permute(0, 3, 1, 2).to(device))
         model.to(offload_device)
+        torch._C._jit_set_texpr_fuser_enabled(prev_fuser_state)

         pred = dict_to_device(pred, offload_device)
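The fuser handling above is a save/disable/restore pattern around the TorchScript call. The same idea can be expressed as a reusable context manager; this is an illustrative sketch, not part of the commit:

from contextlib import contextmanager
import torch

@contextmanager
def texpr_fuser_disabled():
    # Save, disable, and restore the TorchScript tensor-expression fuser,
    # mirroring the inline save/restore in NLFPredict.predict above.
    prev = torch._C._jit_texpr_fuser_enabled()
    torch._C._jit_set_texpr_fuser_enabled(False)
    try:
        yield
    finally:
        torch._C._jit_set_texpr_fuser_enabled(prev)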
@@ -197,7 +187,7 @@ class NLFPredict:
                 pose_results[key].append(pred[key])
             else:
                 pose_results[key].append(None)

         return (pose_results,)

 class DrawNLFPoses:
@@ -208,21 +198,27 @@ class DrawNLFPoses:
                 "width": ("INT", {"default": 512}),
                 "height": ("INT", {"default": 512}),
             },
-        }
+            "optional": {
+                "stick_width": ("FLOAT", {"default": 4.0, "min": 0.0, "max": 1000.0, "step": 0.01, "tooltip": "Stick width multiplier"}),
+                "point_radius": ("INT", {"default": 5, "min": 1, "max": 10, "step": 1, "tooltip": "Point radius for drawing the pose"}),
+                "style": (["original", "scail"], {"default": "original", "tooltip": "style of the pose drawing"}),
+            }
+        }

     RETURN_TYPES = ("IMAGE", )
     RETURN_NAMES = ("image",)
     FUNCTION = "predict"
     CATEGORY = "WanVideoWrapper"

-    def predict(self, poses, width, height):
+    def predict(self, poses, width, height, stick_width=1.0, point_radius=2, style="original"):
         from .draw_pose import get_control_conditions
         print(type(poses))

         if isinstance(poses, dict):
             pose_input = poses['joints3d_nonparam'][0] if 'joints3d_nonparam' in poses else poses
         else:
             pose_input = poses
-        control_conditions = get_control_conditions(pose_input, height, width)
+
+        control_conditions = get_control_conditions(pose_input, height, width, stick_width=stick_width, point_radius=point_radius, style=style)

         return (control_conditions,)

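A minimal sketch of driving the updated node directly for testing, outside a ComfyUI graph; pose_results here stands for the dict produced by NLFPredict and is an assumed input:

node = DrawNLFPoses()
(image,) = node.predict(
    pose_results,       # dict with 'joints3d_nonparam', as returned by NLFPredict
    width=512,
    height=512,
    stick_width=4.0,    # matches the node's new default
    point_radius=5,
    style="scail",      # or "original" for the previous look
)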
SCAIL/for_testing/SCAIL_wip_testing.json (new file, 3427 lines)
File diff suppressed because one or more lines are too long