mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-28 12:20:55 +03:00

Add v2 sampling nodes for cleaner workflows

No functional changes, just cleaner nodes
kijai
2025-12-13 15:42:35 +02:00
parent 164a6bbebd
commit 78e3e1857c
3 changed files with 202 additions and 155 deletions
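In workflow terms, the v2 nodes collapse the sampler's long input list into a few packed connections. The sketch below shows the intended wiring as plain Python calls; it is illustrative only, not a runnable workflow, and `model` and `image_embeds` stand in for the outputs of the usual WanVideo loader and embed nodes:

# Hypothetical wiring of the v2 nodes (illustrative; `model` and
# `image_embeds` are placeholders for upstream node outputs).
(scheduler,) = WanVideoSchedulerv2().process(
    scheduler="unipc", steps=30, shift=5.0,
    start_step=0, end_step=-1, unique_id=None)

(extra_args,) = WanVideoSamplerExtraArgs().process(
    riflex_freq_index=0, rope_function="comfy")

# The sampler consumes the packed scheduler and extra-args dicts, so
# only the commonly-touched inputs remain as individual sockets.
samples = WanVideoSamplerv2().process(
    model=model, image_embeds=image_embeds, cfg=6.0, seed=0,
    force_offload=True, scheduler=scheduler, extra_args=extra_args)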

nodes.py

@@ -1,11 +1,8 @@
import os, gc, math
import torch
import torch.nn.functional as F
import numpy as np
import hashlib
from .wanvideo.schedulers import get_scheduler, scheduler_list
from .utils import (log, clip_encode_image_tiled, add_noise_to_reference_video, set_module_tensor_to_device)
from .taehv import TAEHV
@@ -1927,155 +1924,6 @@ class WanVideoFreeInitArgs:
def process(self, **kwargs):
return (kwargs,)
class WanVideoScheduler: #WIP
@classmethod
def INPUT_TYPES(s):
return {"required": {
"scheduler": (scheduler_list, {"default": "unipc"}),
"steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}),
"shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
"start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}),
"end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"})
},
"optional": {
"sigmas": ("SIGMAS", ),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("SIGMAS", "INT", "FLOAT", scheduler_list, "INT", "INT",)
RETURN_NAMES = ("sigmas", "steps", "shift", "scheduler", "start_step", "end_step")
FUNCTION = "process"
CATEGORY = "WanVideoWrapper"
EXPERIMENTAL = True
def process(self, scheduler, steps, start_step, end_step, shift, unique_id, sigmas=None):
sample_scheduler, timesteps, start_idx, end_idx = get_scheduler(
scheduler,
steps,
start_step, end_step, shift,
device,
sigmas=sigmas,
log_timesteps=True)
scheduler_dict = {
"sample_scheduler": sample_scheduler,
"timesteps": timesteps,
}
try:
from server import PromptServer
import io
import base64
import matplotlib.pyplot as plt
except ImportError:
PromptServer = None
if unique_id and PromptServer is not None:
try:
# Plot sigmas and save to a buffer
sigmas_np = sample_scheduler.full_sigmas.cpu().numpy()
if not np.isclose(sigmas_np[-1], 0.0, atol=1e-6):
sigmas_np = np.append(sigmas_np, 0.0)
buf = io.BytesIO()
fig = plt.figure(facecolor='#353535')
ax = fig.add_subplot(111)
ax.set_facecolor('#353535') # Set axes background color
x_values = range(0, len(sigmas_np))
ax.plot(x_values, sigmas_np)
# Annotate each sigma value
ax.scatter(x_values, sigmas_np, color='white', s=20, zorder=3) # Small dots at each sigma
for x, y in zip(x_values, sigmas_np):
# Show all annotations if few steps, or just show split step annotations
show_annotation = len(sigmas_np) <= 10
is_split_step = (start_idx > 0 and x == start_idx) or (end_idx != -1 and x == end_idx + 1)
if show_annotation or is_split_step:
color = 'orange'
if is_split_step:
color = 'yellow'
ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points", xytext=(10, 1), ha='center', color=color, fontsize=12)
ax.set_xticks(x_values)
ax.set_title("Sigmas", color='white') # Title font color
ax.set_xlabel("Step", color='white') # X label font color
ax.set_ylabel("Sigma Value", color='white') # Y label font color
ax.tick_params(axis='x', colors='white', labelsize=10) # X tick color
ax.tick_params(axis='y', colors='white', labelsize=10) # Y tick color
# Add split point if end_step is defined
end_idx += 1
if end_idx != -1 and 0 <= end_idx < len(sigmas_np) - 1:
ax.axvline(end_idx, color='red', linestyle='--', linewidth=2, label='end_step split')
# Add split point if start_step is defined
if start_idx > 0 and 0 <= start_idx < len(sigmas_np):
ax.axvline(start_idx, color='green', linestyle='--', linewidth=2, label='start_step split')
if (end_idx != -1 and 0 <= end_idx < len(sigmas_np)) or (start_idx > 0 and 0 <= start_idx < len(sigmas_np)):
handles, labels = ax.get_legend_handles_labels()
if labels:
ax.legend()
if start_idx < end_idx and 0 <= start_idx < len(sigmas_np) and 0 < end_idx < len(sigmas_np):
ax.axvspan(start_idx, end_idx, color='lightblue', alpha=0.1, label='Sampled Range')
plt.tight_layout()
plt.savefig(buf, format='png')
plt.close(fig)
buf.seek(0)
img_base64 = base64.b64encode(buf.read()).decode('utf-8')
buf.close()
# Send as HTML img tag with base64 data
html_img = f"<img src='data:image/png;base64,{img_base64}' alt='Sigmas Plot' style='max-width:100%; height:100%; overflow:hidden; display:block;'>"
PromptServer.instance.send_progress_text(html_img, unique_id)
except Exception as e:
print("Failed to send sigmas plot:", e)
return (sigmas, steps, shift, scheduler_dict, start_step, end_step)
class WanVideoSchedulerSA_ODE:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"use_adaptive_order": ("BOOLEAN", {"default": False, "tooltip": "Use adaptive order"}),
"use_velocity_smoothing": ("BOOLEAN", {"default": True, "tooltip": "Use velocity smoothing"}),
"convergence_threshold": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Convergence threshold for velocity smoothing"}),
"smoothing_factor": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001, "tooltip": "Smoothing factor for velocity smoothing"}),
"steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}),
"shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
"start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}),
"end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"})
},
"optional": {
"sigmas": ("SIGMAS", ),
},
}
RETURN_TYPES = ("SIGMAS", "INT", "FLOAT", scheduler_list, "INT", "INT",)
RETURN_NAMES = ("sigmas", "steps", "shift", "scheduler", "start_step", "end_step")
FUNCTION = "process"
CATEGORY = "WanVideoWrapper"
EXPERIMENTAL = True
def process(self, steps, start_step, end_step, shift, use_adaptive_order, use_velocity_smoothing, convergence_threshold, smoothing_factor, sigmas=None):
sample_scheduler, timesteps, _, _ = get_scheduler(
scheduler="sa_ode_stable/lowstep",
steps=steps,
start_step=start_step, end_step=end_step, shift=shift,
device=device,
sigmas=sigmas,
log_timesteps=True,
use_adaptive_order=use_adaptive_order,
use_velocity_smoothing=use_velocity_smoothing,
convergence_threshold=convergence_threshold,
smoothing_factor=smoothing_factor
)
scheduler_dict = {
"sample_scheduler": sample_scheduler,
"timesteps": timesteps,
}
return (sigmas, steps, shift, scheduler_dict, start_step, end_step)
rope_functions = ["default", "comfy", "comfy_chunked"]
class WanVideoRoPEFunction:
@@ -2396,7 +2244,6 @@ NODE_CLASS_MAPPINGS = {
"WanVideoBlockList": WanVideoBlockList,
"WanVideoTextEncodeCached": WanVideoTextEncodeCached,
"WanVideoAddExtraLatent": WanVideoAddExtraLatent,
"WanVideoScheduler": WanVideoScheduler,
"WanVideoAddStandInLatent": WanVideoAddStandInLatent,
"WanVideoAddControlEmbeds": WanVideoAddControlEmbeds,
"WanVideoAddMTVMotion": WanVideoAddMTVMotion,
@@ -2404,7 +2251,6 @@ NODE_CLASS_MAPPINGS = {
"WanVideoAddPusaNoise": WanVideoAddPusaNoise,
"WanVideoAnimateEmbeds": WanVideoAnimateEmbeds,
"WanVideoAddLucyEditLatents": WanVideoAddLucyEditLatents,
"WanVideoSchedulerSA_ODE": WanVideoSchedulerSA_ODE,
"WanVideoAddBindweaveEmbeds": WanVideoAddBindweaveEmbeds,
"TextImageEncodeQwenVL": TextImageEncodeQwenVL,
"WanVideoUniLumosEmbeds": WanVideoUniLumosEmbeds,
@@ -2447,7 +2293,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoAddPusaNoise": "WanVideo Add Pusa Noise",
"WanVideoAnimateEmbeds": "WanVideo Animate Embeds",
"WanVideoAddLucyEditLatents": "WanVideo Add LucyEdit Latents",
"WanVideoSchedulerSA_ODE": "WanVideo Scheduler SA-ODE",
"WanVideoAddBindweaveEmbeds": "WanVideo Add Bindweave Embeds",
"WanVideoUniLumosEmbeds": "WanVideo UniLumos Embeds",
"WanVideoAddTTMLatents": "WanVideo Add TTMLatents",

View File

@@ -1055,6 +1055,8 @@ class WanVideoSampler:
rope_function = "comfy" # only works with this currently
freqs = None
riflex_freq_index = 0 if riflex_freq_index is None else riflex_freq_index
transformer.rope_embedder.k = None
transformer.rope_embedder.num_frames = None
d = transformer.dim // transformer.num_heads
@@ -3274,13 +3276,211 @@ class WanVideoSamplerFromSettings(WanVideoSampler):
def process(self, sampler_inputs):
return super().process(**sampler_inputs)
class WanVideoSamplerExtraArgs:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
},
"optional": {
"riflex_freq_index": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1, "tooltip": "Frequency index for RIFLEX, disabled when 0, default 6. Allows for new frames to be generated after without looping"}),
"feta_args": ("FETAARGS", ),
"context_options": ("WANVIDCONTEXT", ),
"cache_args": ("CACHEARGS", ),
"slg_args": ("SLGARGS", ),
"rope_function": (rope_functions, {"default": "comfy", "tooltip": "Comfy's RoPE implementation doesn't use complex numbers and can thus be compiled, that should be a lot faster when using torch.compile. Chunked version has reduced peak VRAM usage when not using torch.compile"}),
"loop_args": ("LOOPARGS", ),
"experimental_args": ("EXPERIMENTALARGS", ),
"unianimate_poses": ("UNIANIMATE_POSE", ),
"fantasytalking_embeds": ("FANTASYTALKING_EMBEDS", ),
"uni3c_embeds": ("UNI3C_EMBEDS", ),
"multitalk_embeds": ("MULTITALK_EMBEDS", ),
}
}
RETURN_TYPES = ("WANVIDSAMPLEREXTRAARGS",)
RETURN_NAMES = ("extra_args", )
FUNCTION = "process"
CATEGORY = "WanVideoWrapper"
def process(self, *args, **kwargs):
return (kwargs,)
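The extra-args node is nothing more than a kwargs packer: whatever optional inputs are connected arrive as keyword arguments and leave as one dict. A minimal, self-contained sketch of the same pattern (hypothetical class name):

class KwargsPacker:
    def process(self, *args, **kwargs):
        # ComfyUI node outputs are tuples, hence the 1-tuple
        return (kwargs,)

(extra,) = KwargsPacker().process(riflex_freq_index=6, rope_function="comfy")
print(extra)  # {'riflex_freq_index': 6, 'rope_function': 'comfy'}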
class WanVideoSamplerv2(WanVideoSampler):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("WANVIDEOMODEL",),
"image_embeds": ("WANVIDIMAGE_EMBEDS", ),
"cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}),
"scheduler": ("WANVIDEOSCHEDULER",),
},
"optional": {
"text_embeds": ("WANVIDEOTEXTEMBEDS", ),
"samples": ("LATENT", {"tooltip": "init Latents to use for video2video process"} ),
"add_noise_to_samples": ("BOOLEAN", {"default": False, "tooltip": "Add noise to the samples before sampling, needed for video2video sampling when starting from clean video"}),
"extra_args": ("WANVIDSAMPLEREXTRAARGS", ),
}
}
def process(self, *args, extra_args=None, **kwargs):
import inspect
params = inspect.signature(WanVideoSampler.process).parameters
args_dict = {name: kwargs.get(name, param.default if param.default is not inspect.Parameter.empty else None)
for name, param in params.items() if name != "self"}
if extra_args is not None:
args_dict.update(extra_args)
return super().process(**args_dict)
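The forwarding trick here is worth isolating: the v2 node reads the parent method's signature, fills any argument the caller omitted with its declared default, then lets the packed extra_args dict override. A self-contained sketch of the same mechanism (toy class names, not the real sampler):

import inspect

class Base:
    def process(self, a, b=2, c=3):
        return (a, b, c)

class V2(Base):
    def process(self, extra_args=None, **kwargs):
        # Anything the caller omitted falls back to the declared
        # default (or None when the parameter has no default).
        params = inspect.signature(Base.process).parameters
        args_dict = {name: kwargs.get(name, p.default if p.default is not inspect.Parameter.empty else None)
                     for name, p in params.items() if name != "self"}
        if extra_args is not None:
            args_dict.update(extra_args)  # packed dict wins over defaults
        return super().process(**args_dict)

print(V2().process(a=1))                        # (1, 2, 3)
print(V2().process(a=1, extra_args={"c": 9}))   # (1, 2, 9)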
class WanVideoScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"scheduler": (scheduler_list, {"default": "unipc"}),
"steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}),
"shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
"start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}),
"end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"})
},
"optional": {
"sigmas": ("SIGMAS", ),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("SIGMAS", "INT", "FLOAT", scheduler_list, "INT", "INT",)
RETURN_NAMES = ("sigmas", "steps", "shift", "scheduler", "start_step", "end_step")
FUNCTION = "process"
CATEGORY = "WanVideoWrapper"
EXPERIMENTAL = True
def process(self, scheduler, steps, start_step, end_step, shift, unique_id, sigmas=None):
sample_scheduler, timesteps, start_idx, end_idx = get_scheduler(
scheduler, steps, start_step, end_step, shift, device, sigmas=sigmas, log_timesteps=True)
scheduler_dict = {
"sample_scheduler": sample_scheduler,
"timesteps": timesteps,
}
try:
from server import PromptServer
import io
import base64
import matplotlib.pyplot as plt
except ImportError:
PromptServer = None
if unique_id and PromptServer is not None:
try:
# Plot sigmas and save to a buffer
sigmas_np = sample_scheduler.full_sigmas.cpu().numpy()
if not np.isclose(sigmas_np[-1], 0.0, atol=1e-6):
sigmas_np = np.append(sigmas_np, 0.0)
buf = io.BytesIO()
fig = plt.figure(facecolor='#353535')
ax = fig.add_subplot(111)
ax.set_facecolor('#353535') # Set axes background color
x_values = range(0, len(sigmas_np))
ax.plot(x_values, sigmas_np)
# Annotate each sigma value
ax.scatter(x_values, sigmas_np, color='white', s=20, zorder=3) # Small dots at each sigma
for x, y in zip(x_values, sigmas_np):
# Show all annotations if few steps, or just show split step annotations
show_annotation = len(sigmas_np) <= 10
is_split_step = (start_idx > 0 and x == start_idx) or (end_idx != -1 and x == end_idx + 1)
if show_annotation or is_split_step:
color = 'orange'
if is_split_step:
color = 'yellow'
ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points", xytext=(10, 1), ha='center', color=color, fontsize=12)
ax.set_xticks(x_values)
ax.set_title("Sigmas", color='white') # Title font color
ax.set_xlabel("Step", color='white') # X label font color
ax.set_ylabel("Sigma Value", color='white') # Y label font color
ax.tick_params(axis='x', colors='white', labelsize=10) # X tick color
ax.tick_params(axis='y', colors='white', labelsize=10) # Y tick color
# Add split point if end_step is defined
end_idx += 1
if end_idx != -1 and 0 <= end_idx < len(sigmas_np) - 1:
ax.axvline(end_idx, color='red', linestyle='--', linewidth=2, label='end_step split')
# Add split point if start_step is defined
if start_idx > 0 and 0 <= start_idx < len(sigmas_np):
ax.axvline(start_idx, color='green', linestyle='--', linewidth=2, label='start_step split')
if (end_idx != -1 and 0 <= end_idx < len(sigmas_np)) or (start_idx > 0 and 0 <= start_idx < len(sigmas_np)):
handles, labels = ax.get_legend_handles_labels()
if labels:
ax.legend()
if start_idx < end_idx and 0 <= start_idx < len(sigmas_np) and 0 < end_idx < len(sigmas_np):
ax.axvspan(start_idx, end_idx, color='lightblue', alpha=0.1, label='Sampled Range')
plt.tight_layout()
plt.savefig(buf, format='png')
plt.close(fig)
buf.seek(0)
img_base64 = base64.b64encode(buf.read()).decode('utf-8')
buf.close()
# Send as HTML img tag with base64 data
html_img = f"<img src='data:image/png;base64,{img_base64}' alt='Sigmas Plot' style='max-width:100%; height:100%; overflow:hidden; display:block;'>"
PromptServer.instance.send_progress_text(html_img, unique_id)
except Exception as e:
log.error(f"Failed to send sigmas plot: {e}")
return (sigmas, steps, shift, scheduler_dict, start_step, end_step)
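The preview plot never touches disk: the figure is rendered into an in-memory PNG, base64-encoded, and shipped to the front end as an <img> tag via PromptServer.send_progress_text. The same round trip in isolation (a sketch; the Agg backend is assumed for headless use, and the function name is hypothetical):

import io, base64
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, as on a server
import matplotlib.pyplot as plt

def sigmas_plot_html(sigmas_np: np.ndarray) -> str:
    fig, ax = plt.subplots(facecolor="#353535")
    ax.set_facecolor("#353535")
    ax.plot(range(len(sigmas_np)), sigmas_np, marker="o", color="white")
    ax.set_title("Sigmas", color="white")
    buf = io.BytesIO()
    fig.savefig(buf, format="png")  # render straight into memory
    plt.close(fig)
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f"<img src='data:image/png;base64,{b64}' alt='Sigmas Plot'>"

html = sigmas_plot_html(np.linspace(1.0, 0.0, 10))
# PromptServer.instance.send_progress_text(html, unique_id)  # inside ComfyUI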
class WanVideoSchedulerv2(WanVideoScheduler):
@classmethod
def INPUT_TYPES(s):
return {"required": {
"scheduler": (scheduler_list, {"default": "unipc"}),
"steps": ("INT", {"default": 30, "min": 1, "tooltip": "Number of steps for the scheduler"}),
"shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
"start_step": ("INT", {"default": 0, "min": 0, "tooltip": "Starting step for the scheduler"}),
"end_step": ("INT", {"default": -1, "min": -1, "tooltip": "Ending step for the scheduler"})
},
"optional": {
"sigmas": ("SIGMAS", ),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("WANVIDEOSCHEDULER",)
RETURN_NAMES = ("scheduler",)
FUNCTION = "process"
CATEGORY = "WanVideoWrapper"
EXPERIMENTAL = True
def process(self, *args, **kwargs):
sigmas, steps, shift, scheduler_dict, start_step, end_step = super().process(*args, **kwargs)
return (scheduler_dict,)
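Pattern-wise, the v2 scheduler reuses the parent's computation and repacks it into a single output. The same idea in isolation (a toy sketch with a linear schedule, not the real scheduler math):

class Scheduler:
    def process(self, steps):
        sigmas = [1 - i / steps for i in range(steps + 1)]
        return (sigmas, steps)

class SchedulerV2(Scheduler):
    def process(self, *args, **kwargs):
        sigmas, steps = super().process(*args, **kwargs)
        # expose one packed dict instead of loose outputs
        return ({"sigmas": sigmas, "steps": steps},)

(packed,) = SchedulerV2().process(steps=4)
print(packed["sigmas"])  # [1.0, 0.75, 0.5, 0.25, 0.0]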
NODE_CLASS_MAPPINGS = {
"WanVideoSampler": WanVideoSampler,
"WanVideoSamplerSettings": WanVideoSamplerSettings,
"WanVideoSamplerFromSettings": WanVideoSamplerFromSettings,
"WanVideoSamplerv2": WanVideoSamplerv2,
"WanVideoSamplerExtraArgs": WanVideoSamplerExtraArgs,
"WanVideoScheduler": WanVideoScheduler,
"WanVideoSchedulerv2": WanVideoSchedulerv2,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoSampler": "WanVideo Sampler",
"WanVideoSamplerSettings": "WanVideo Sampler Settings",
"WanVideoSamplerFromSettings": "WanVideo Sampler From Settings",
"WanVideoSamplerv2": "WanVideo Sampler v2",
"WanVideoSamplerExtraArgs": "WanVideoSampler v2 Extra Args",
"WanVideoScheduler": "WanVideo Scheduler",
"WanVideoSchedulerv2": "WanVideo Scheduler v2",
}


@@ -41,6 +41,8 @@ def _apply_custom_sigmas(sample_scheduler, sigmas, device):
def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transformer_dim=5120, flowedit_args=None, denoise_strength=1.0, sigmas=None, log_timesteps=False, **kwargs):
timesteps = None
if sigmas is not None:
steps = len(sigmas) - 1
if scheduler == 'vibt_unipc':
sample_scheduler = ViBTScheduler()
sample_scheduler.set_parameters(shift=shift)
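The `steps = len(sigmas) - 1` line reflects the usual convention that a custom sigma schedule lists N + 1 noise levels bounding N sampling steps:

import torch

sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])  # 5 boundaries
steps = len(sigmas) - 1                             # -> 4 steps
intervals = list(zip(sigmas[:-1].tolist(), sigmas[1:].tolist()))
print(steps, intervals)
# 4 [(1.0, 0.75), (0.75, 0.5), (0.5, 0.25), (0.25, 0.0)]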