mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-26 23:41:35 +03:00

Merge branch 'main' into longvie2

kijai
2025-12-29 15:03:53 +02:00
44 changed files with 1463 additions and 10758 deletions

@@ -37,15 +37,16 @@ class WanVideoLongCatAvatarExtendEmbeds(io.ComfyNode):
new_audio_embed = audio_embeds.copy()
audio_features = torch.stack(new_audio_embed["audio_features"])
num_audio_features = audio_features.shape[1]
if audio_features.shape[1] < frames_processed + num_frames:
deficit = frames_processed + num_frames - audio_features.shape[1]
if if_not_enough_audio == "pad_with_start":
pad = audio_features[:, :1].repeat(1, deficit, 1, 1, 1)
pad = audio_features[:, :1].repeat(1, deficit, 1, 1)
audio_features = torch.cat([audio_features, pad], dim=1)
elif if_not_enough_audio == "mirror_from_end":
to_add = audio_features[:, -deficit:, :].flip(dims=[1])
audio_features = torch.cat([audio_features, to_add], dim=1)
log.info(f"Not enough audio features, extended from {new_audio_embed['audio_features'].shape[1]} to {audio_features.shape[1]} frames.")
log.warning(f"Not enough audio features, padded with strategy '{if_not_enough_audio}' from {num_audio_features} to {audio_features.shape[1]} frames")
ref_target_masks = new_audio_embed.get("ref_target_masks", None)
if ref_target_masks is not None:
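
For reference, a minimal standalone sketch of the two padding strategies handled in the hunk above; the [speakers, frames, window, channels] layout and sizes are assumptions for illustration, not the node's actual tensors:

import torch

def extend_audio_features(audio_features: torch.Tensor, needed: int, strategy: str) -> torch.Tensor:
    deficit = needed - audio_features.shape[1]
    if deficit <= 0:
        return audio_features
    if strategy == "pad_with_start":
        # repeat the first frame's features `deficit` times
        pad = audio_features[:, :1].repeat(1, deficit, 1, 1)
    elif strategy == "mirror_from_end":
        # mirror the last `deficit` frames back onto the end
        pad = audio_features[:, -deficit:].flip(dims=[1])
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return torch.cat([audio_features, pad], dim=1)

feats = torch.randn(1, 10, 5, 768)                                # 10 frames available
print(extend_audio_features(feats, 16, "pad_with_start").shape)   # torch.Size([1, 16, 5, 768])
print(extend_audio_features(feats, 16, "mirror_from_end").shape)  # torch.Size([1, 16, 5, 768])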

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

@@ -1,780 +0,0 @@
{
"id": "8b7a9a57-2303-4ef5-9fc2-bf41713bd1fc",
"revision": 0,
"last_node_id": 46,
"last_link_id": 58,
"nodes": [
{
"id": 33,
"type": "Note",
"pos": [
227.3764190673828,
-205.28524780273438
],
"size": [
351.70458984375,
88
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {},
"widgets_values": [
"Models:\nhttps://huggingface.co/Kijai/WanVideo_comfy/tree/main"
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 11,
"type": "LoadWanVideoT5TextEncoder",
"pos": [
224.15325927734375,
-34.481563568115234
],
"size": [
377.1661376953125,
130
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "wan_t5_model",
"type": "WANTEXTENCODER",
"slot_index": 0,
"links": [
15
]
}
],
"properties": {
"Node name for S&R": "LoadWanVideoT5TextEncoder"
},
"widgets_values": [
"umt5-xxl-enc-bf16.safetensors",
"bf16",
"offload_device",
"disabled"
]
},
{
"id": 28,
"type": "WanVideoDecode",
"pos": [
1692.973876953125,
-404.8614501953125
],
"size": [
315,
174
],
"flags": {},
"order": 12,
"mode": 0,
"inputs": [
{
"name": "vae",
"type": "WANVAE",
"link": 43
},
{
"name": "samples",
"type": "LATENT",
"link": 33
}
],
"outputs": [
{
"name": "images",
"type": "IMAGE",
"slot_index": 0,
"links": [
48
]
}
],
"properties": {
"Node name for S&R": "WanVideoDecode"
},
"widgets_values": [
true,
272,
272,
144,
128
]
},
{
"id": 38,
"type": "WanVideoVAELoader",
"pos": [
1687.4093017578125,
-582.2750854492188
],
"size": [
416.25482177734375,
82
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "vae",
"type": "WANVAE",
"slot_index": 0,
"links": [
43
]
}
],
"properties": {
"Node name for S&R": "WanVideoVAELoader"
},
"widgets_values": [
"wanvideo\\Wan2_1_VAE_bf16.safetensors",
"bf16"
]
},
{
"id": 42,
"type": "GetImageSizeAndCount",
"pos": [
1708.7301025390625,
-140.99705505371094
],
"size": [
277.20001220703125,
86
],
"flags": {},
"order": 13,
"mode": 0,
"inputs": [
{
"name": "image",
"type": "IMAGE",
"link": 48
}
],
"outputs": [
{
"name": "image",
"type": "IMAGE",
"slot_index": 0,
"links": [
56
]
},
{
"label": "832 width",
"name": "width",
"type": "INT",
"links": null
},
{
"label": "480 height",
"name": "height",
"type": "INT",
"links": null
},
{
"label": "257 count",
"name": "count",
"type": "INT",
"links": null
}
],
"properties": {
"Node name for S&R": "GetImageSizeAndCount"
},
"widgets_values": []
},
{
"id": 16,
"type": "WanVideoTextEncode",
"pos": [
675.8850708007812,
-36.032100677490234
],
"size": [
420.30511474609375,
261.5306701660156
],
"flags": {},
"order": 10,
"mode": 0,
"inputs": [
{
"name": "t5",
"type": "WANTEXTENCODER",
"link": 15
},
{
"name": "model_to_offload",
"shape": 7,
"type": "WANVIDEOMODEL",
"link": null
}
],
"outputs": [
{
"name": "text_embeds",
"type": "WANVIDEOTEXTEMBEDS",
"slot_index": 0,
"links": [
30
]
}
],
"properties": {
"Node name for S&R": "WanVideoTextEncode"
},
"widgets_values": [
"high quality nature video featuring a red panda balancing on a bamboo stem while a bird lands on it's head, on the background there is a waterfall",
"色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
true
]
},
{
"id": 30,
"type": "VHS_VideoCombine",
"pos": [
2127.120849609375,
-511.9014587402344
],
"size": [
873.2135620117188,
840.2385864257812
],
"flags": {},
"order": 14,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 56
},
{
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": null
},
{
"name": "meta_batch",
"shape": 7,
"type": "VHS_BatchManager",
"link": null
},
{
"name": "vae",
"shape": 7,
"type": "VAE",
"link": null
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null
}
],
"properties": {
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 16,
"loop_count": 0,
"filename_prefix": "WanVideo2_1_T2V",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
"save_metadata": true,
"trim_to_audio": false,
"pingpong": false,
"save_output": true,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "WanVideo2_1_T2V_00412.mp4",
"subfolder": "",
"type": "output",
"format": "video/h264-mp4",
"frame_rate": 16,
"workflow": "WanVideo2_1_T2V_00412.png",
"fullpath": "N:\\AI\\ComfyUI\\output\\WanVideo2_1_T2V_00412.mp4"
}
}
}
},
{
"id": 37,
"type": "WanVideoEmptyEmbeds",
"pos": [
1305.26708984375,
-571.7843627929688
],
"size": [
315,
106
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "image_embeds",
"type": "WANVIDIMAGE_EMBEDS",
"links": [
42
]
}
],
"properties": {
"Node name for S&R": "WanVideoEmptyEmbeds"
},
"widgets_values": [
832,
480,
257
]
},
{
"id": 35,
"type": "WanVideoTorchCompileSettings",
"pos": [
193.47103881835938,
-614.6900024414062
],
"size": [
390.5999755859375,
178
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "torch_compile_args",
"type": "WANCOMPILEARGS",
"slot_index": 0,
"links": []
}
],
"properties": {
"Node name for S&R": "WanVideoTorchCompileSettings"
},
"widgets_values": [
"inductor",
false,
"default",
false,
64,
true
]
},
{
"id": 45,
"type": "WanVideoTeaCache",
"pos": [
931.4036865234375,
-792.5159912109375
],
"size": [
315,
154
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "cache_args",
"type": "CACHEARGS",
"links": [
58
]
}
],
"properties": {
"Node name for S&R": "WanVideoTeaCache"
},
"widgets_values": [
0.10000000000000002,
1,
-1,
"offload_device",
true
]
},
{
"id": 36,
"type": "Note",
"pos": [
796.0189208984375,
-521.5020751953125
],
"size": [
298.2554016113281,
108.62744140625
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {},
"widgets_values": [
"sdpa should work too, haven't tested flaash\n\nfp8_fast seems to cause huge quality degradation"
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 46,
"type": "Note",
"pos": [
937.9556274414062,
-940.750244140625
],
"size": [
297.4364013671875,
88
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [],
"outputs": [],
"properties": {},
"widgets_values": [
"TeaCache with context windows is VERY experimental and lower values than normal should be used."
],
"color": "#432",
"bgcolor": "#653"
},
{
"id": 27,
"type": "WanVideoSampler",
"pos": [
1315.2401123046875,
-401.48028564453125
],
"size": [
315,
574.1923217773438
],
"flags": {},
"order": 11,
"mode": 0,
"inputs": [
{
"name": "model",
"type": "WANVIDEOMODEL",
"link": 29
},
{
"name": "text_embeds",
"type": "WANVIDEOTEXTEMBEDS",
"link": 30
},
{
"name": "image_embeds",
"type": "WANVIDIMAGE_EMBEDS",
"link": 42
},
{
"name": "samples",
"shape": 7,
"type": "LATENT",
"link": null
},
{
"name": "feta_args",
"shape": 7,
"type": "FETAARGS",
"link": null
},
{
"name": "context_options",
"shape": 7,
"type": "WANVIDCONTEXT",
"link": 57
},
{
"name": "cache_args",
"shape": 7,
"type": "CACHEARGS",
"link": 58
},
{
"name": "flowedit_args",
"shape": 7,
"type": "FLOWEDITARGS",
"link": null
},
{
"name": "slg_args",
"shape": 7,
"type": "SLGARGS",
"link": null
},
{
"name": "loop_args",
"shape": 7,
"type": "LOOPARGS",
"link": null
}
],
"outputs": [
{
"name": "samples",
"type": "LATENT",
"slot_index": 0,
"links": [
33
]
}
],
"properties": {
"Node name for S&R": "WanVideoSampler"
},
"widgets_values": [
30,
6,
5,
1057359483639288,
"fixed",
true,
"unipc",
0,
1,
"",
"comfy"
]
},
{
"id": 43,
"type": "WanVideoContextOptions",
"pos": [
1307.9542236328125,
-855.8865356445312
],
"size": [
315,
226
],
"flags": {},
"order": 8,
"mode": 0,
"inputs": [
{
"name": "vae",
"shape": 7,
"type": "WANVAE",
"link": null
}
],
"outputs": [
{
"name": "context_options",
"type": "WANVIDCONTEXT",
"slot_index": 0,
"links": [
57
]
}
],
"properties": {
"Node name for S&R": "WanVideoContextOptions"
},
"widgets_values": [
"uniform_standard",
81,
4,
16,
true,
false,
6,
2
]
},
{
"id": 22,
"type": "WanVideoModelLoader",
"pos": [
620.3950805664062,
-357.8426818847656
],
"size": [
477.4410095214844,
226.43276977539062
],
"flags": {},
"order": 9,
"mode": 0,
"inputs": [
{
"name": "compile_args",
"shape": 7,
"type": "WANCOMPILEARGS",
"link": null
},
{
"name": "block_swap_args",
"shape": 7,
"type": "BLOCKSWAPARGS",
"link": null
},
{
"name": "lora",
"shape": 7,
"type": "WANVIDLORA",
"link": null
},
{
"name": "vram_management_args",
"shape": 7,
"type": "VRAM_MANAGEMENTARGS",
"link": null
}
],
"outputs": [
{
"name": "model",
"type": "WANVIDEOMODEL",
"slot_index": 0,
"links": [
29
]
}
],
"properties": {
"Node name for S&R": "WanVideoModelLoader"
},
"widgets_values": [
"WanVideo\\wan2.1_t2v_1.3B_fp16.safetensors",
"fp16",
"disabled",
"offload_device",
"sdpa"
]
}
],
"links": [
[
15,
11,
0,
16,
0,
"WANTEXTENCODER"
],
[
29,
22,
0,
27,
0,
"WANVIDEOMODEL"
],
[
30,
16,
0,
27,
1,
"WANVIDEOTEXTEMBEDS"
],
[
33,
27,
0,
28,
1,
"LATENT"
],
[
42,
37,
0,
27,
2,
"WANVIDIMAGE_EMBEDS"
],
[
43,
38,
0,
28,
0,
"VAE"
],
[
48,
28,
0,
42,
0,
"IMAGE"
],
[
56,
42,
0,
30,
0,
"IMAGE"
],
[
57,
43,
0,
27,
5,
"WANVIDCONTEXT"
],
[
58,
45,
0,
27,
6,
"TEACACHEARGS"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.8140274938684471,
"offset": [
-122.25834160503663,
993.5739491626379
]
},
"node_versions": {
"ComfyUI-WanVideoWrapper": "5a2383621a05825d0d0437781afcb8552d9590fd",
"ComfyUI-KJNodes": "a5bd3c86c8ed6b83c55c2d0e7a59515b15a0137f",
"ComfyUI-VideoHelperSuite": "0a75c7958fe320efcb052f1d9f8451fd20c730a8"
},
"VHS_latentpreview": true,
"VHS_latentpreviewrate": 0,
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"version": 0.4
}

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

File diff suppressed because it is too large

multitalk/multitalk_loop.py (new file, 483 lines)

@@ -0,0 +1,483 @@
import torch
import os
import gc
from PIL import Image
import numpy as np
from ..latent_preview import prepare_callback
from ..wanvideo.schedulers import get_scheduler
from .multitalk import timestep_transform, add_noise
from ..utils import log, print_memory, temporal_score_rescaling, offload_transformer, init_blockswap
from comfy.utils import load_torch_file
from ..nodes_model_loading import load_weights
from ..HuMo.nodes import get_audio_emb_window
import comfy.model_management as mm
from tqdm import tqdm
import copy
VAE_STRIDE = (4, 8, 8)
PATCH_SIZE = (1, 2, 2)
vae_upscale_factor = 16
script_directory = os.path.dirname(os.path.abspath(__file__))
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
def multitalk_loop(self, **kwargs):
# Unpack kwargs into local variables
(latent, total_steps, steps, start_step, end_step, shift, cfg, denoise_strength,
sigmas, weight_dtype, transformer, patcher, block_swap_args, model, vae, dtype,
scheduler, scheduler_step_args, text_embeds, image_embeds, multitalk_embeds,
multitalk_audio_embeds, unianim_data, dwpose_data, unianimate_poses, uni3c_embeds,
humo_image_cond, humo_image_cond_neg, humo_audio, humo_reference_count,
add_noise_to_samples, audio_stride, use_tsr, tsr_k, tsr_sigma, fantasy_portrait_input,
noise, timesteps, force_offload, add_cond, control_latents, audio_proj,
control_camera_latents, samples, masks, seed_g, gguf_reader, predict_func
) = (kwargs.get(k) for k in (
'latent', 'total_steps', 'steps', 'start_step', 'end_step', 'shift', 'cfg',
'denoise_strength', 'sigmas', 'weight_dtype', 'transformer', 'patcher',
'block_swap_args', 'model', 'vae', 'dtype', 'scheduler', 'scheduler_step_args',
'text_embeds', 'image_embeds', 'multitalk_embeds', 'multitalk_audio_embeds',
'unianim_data', 'dwpose_data', 'unianimate_poses', 'uni3c_embeds',
'humo_image_cond', 'humo_image_cond_neg', 'humo_audio', 'humo_reference_count',
'add_noise_to_samples', 'audio_stride', 'use_tsr', 'tsr_k', 'tsr_sigma',
'fantasy_portrait_input', 'noise', 'timesteps', 'force_offload', 'add_cond',
'control_latents', 'audio_proj', 'control_camera_latents', 'samples', 'masks',
'seed_g', 'gguf_reader', 'predict_with_cfg'
))
mode = image_embeds.get("multitalk_mode", "multitalk")
if mode == "auto":
mode = transformer.multitalk_model_type.lower()
log.info(f"Multitalk mode: {mode}")
cond_frame = None
offload = image_embeds.get("force_offload", False)
offloaded = False
tiled_vae = image_embeds.get("tiled_vae", False)
frame_num = clip_length = image_embeds.get("frame_window_size", 81)
clip_embeds = image_embeds.get("clip_context", None)
if clip_embeds is not None:
clip_embeds = clip_embeds.to(dtype)
colormatch = image_embeds.get("colormatch", "disabled")
motion_frame = image_embeds.get("motion_frame", 25)
target_w = image_embeds.get("target_w", None)
target_h = image_embeds.get("target_h", None)
original_images = cond_image = image_embeds.get("multitalk_start_image", None)
if original_images is None:
original_images = torch.zeros([noise.shape[0], 1, target_h, target_w], device=device)
output_path = image_embeds.get("output_path", "")
img_counter = 0
if len(multitalk_embeds['audio_features'])==2 and (multitalk_embeds['ref_target_masks'] is None):
face_scale = 0.1
x_min, x_max = int(target_h * face_scale), int(target_h * (1 - face_scale))
lefty_min, lefty_max = int((target_w//2) * face_scale), int((target_w//2) * (1 - face_scale))
righty_min, righty_max = int((target_w//2) * face_scale + (target_w//2)), int((target_w//2) * (1 - face_scale) + (target_w//2))
human_mask1, human_mask2 = (torch.zeros([target_h, target_w]) for _ in range(2))
human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
human_mask2[x_min:x_max, righty_min:righty_max] = 1
background_mask = torch.where((human_mask1 + human_mask2) > 0, torch.tensor(0), torch.tensor(1))
human_masks = [human_mask1, human_mask2, background_mask]
ref_target_masks = torch.stack(human_masks, dim=0)
multitalk_embeds['ref_target_masks'] = ref_target_masks
gen_video_list = []
is_first_clip = True
arrive_last_frame = False
cur_motion_frames_num = 1
audio_start_idx = iteration_count = step_iteration_count = 0
audio_end_idx = (audio_start_idx + clip_length) * audio_stride
indices = (torch.arange(4 + 1) - 2) * 1
current_condframe_index = 0
audio_embedding = multitalk_audio_embeds
human_num = len(audio_embedding)
audio_embs = None
cond_frame = None
uni3c_data = None
if uni3c_embeds is not None:
transformer.controlnet = uni3c_embeds["controlnet"]
uni3c_data = uni3c_embeds.copy()
encoded_silence = None
try:
silence_path = os.path.join(script_directory, "encoded_silence.safetensors")
encoded_silence = load_torch_file(silence_path)["audio_emb"].to(dtype)
except:
log.warning("No encoded silence file found, padding with end of audio embedding instead.")
total_frames = len(audio_embedding[0])
estimated_iterations = total_frames // (frame_num - motion_frame) + 1
callback = prepare_callback(patcher, estimated_iterations)
if frame_num >= total_frames:
arrive_last_frame = True
estimated_iterations = 1
log.info(f"Sampling {total_frames} frames in {estimated_iterations} windows, at {latent.shape[3]*vae_upscale_factor}x{latent.shape[2]*vae_upscale_factor} with {steps} steps")
while True: # start video generation iteratively
self.cache_state = [None, None]
cur_motion_frames_latent_num = int(1 + (cur_motion_frames_num-1) // 4)
if mode == "infinitetalk":
cond_image = original_images[:, :, current_condframe_index:current_condframe_index+1] if cond_image is not None else None
if multitalk_embeds is not None:
audio_embs = []
# split audio with window size
for human_idx in range(human_num):
center_indices = torch.arange(audio_start_idx, audio_end_idx, audio_stride).unsqueeze(1) + indices.unsqueeze(0)
center_indices = torch.clamp(center_indices, min=0, max=audio_embedding[human_idx].shape[0]-1)
audio_emb = audio_embedding[human_idx][center_indices].unsqueeze(0).to(device)
audio_embs.append(audio_emb)
audio_embs = torch.concat(audio_embs, dim=0).to(dtype)
h, w = (cond_image.shape[-2], cond_image.shape[-1]) if cond_image is not None else (target_h, target_w)
lat_h, lat_w = h // VAE_STRIDE[1], w // VAE_STRIDE[2]
latent_frame_num = (frame_num - 1) // 4 + 1
noise = torch.randn(
16, latent_frame_num,
lat_h, lat_w, dtype=torch.float32, device=torch.device("cpu"), generator=seed_g).to(device)
# Calculate the correct latent slice based on current iteration
if is_first_clip:
latent_start_idx = 0
latent_end_idx = noise.shape[1]
else:
new_frames_per_iteration = frame_num - motion_frame
new_latent_frames_per_iteration = ((new_frames_per_iteration - 1) // 4 + 1)
latent_start_idx = iteration_count * new_latent_frames_per_iteration
latent_end_idx = latent_start_idx + noise.shape[1]
if samples is not None:
noise_mask = samples.get("noise_mask", None)
input_samples = samples["samples"]
if input_samples is not None:
input_samples = input_samples.squeeze(0).to(noise)
# Check if we have enough frames in input_samples
if latent_end_idx > input_samples.shape[1]:
# We need more frames than available - pad the input_samples at the end
pad_length = latent_end_idx - input_samples.shape[1]
last_frame = input_samples[:, -1:].repeat(1, pad_length, 1, 1)
input_samples = torch.cat([input_samples, last_frame], dim=1)
input_samples = input_samples[:, latent_start_idx:latent_end_idx]
if noise_mask is not None:
original_image = input_samples.to(device)
assert input_samples.shape[1] == noise.shape[1], f"Slice mismatch: {input_samples.shape[1]} vs {noise.shape[1]}"
if add_noise_to_samples:
latent_timestep = timesteps[0]
noise = noise * latent_timestep / 1000 + (1 - latent_timestep / 1000) * input_samples
else:
noise = input_samples
# diff diff prep
if noise_mask is not None:
if len(noise_mask.shape) == 4:
noise_mask = noise_mask.squeeze(1)
if audio_end_idx > noise_mask.shape[0]:
noise_mask = noise_mask.repeat(audio_end_idx // noise_mask.shape[0], 1, 1)
noise_mask = noise_mask[audio_start_idx:audio_end_idx]
noise_mask = torch.nn.functional.interpolate(
noise_mask.unsqueeze(0).unsqueeze(0), # Add batch and channel dims [1,1,T,H,W]
size=(noise.shape[1], noise.shape[2], noise.shape[3]),
mode='trilinear',
align_corners=False
).repeat(1, noise.shape[0], 1, 1, 1)
thresholds = torch.arange(len(timesteps), dtype=original_image.dtype) / len(timesteps)
thresholds = thresholds.reshape(-1, 1, 1, 1, 1).to(device)
masks = (1-noise_mask.repeat(len(timesteps), 1, 1, 1, 1).to(device)) > thresholds
# zero padding and vae encode for img cond
if cond_image is not None or cond_frame is not None:
cond_ = cond_image if (is_first_clip or humo_image_cond is None) else cond_frame
cond_frame_num = cond_.shape[2]
video_frames = torch.zeros(1, 3, frame_num-cond_frame_num, target_h, target_w, device=device, dtype=vae.dtype)
padding_frames_pixels_values = torch.concat([cond_.to(device, vae.dtype), video_frames], dim=2)
# encode
vae.to(device)
y = vae.encode(padding_frames_pixels_values, device=device, tiled=tiled_vae, pbar=False).to(dtype)[0]
if mode == "multitalk":
latent_motion_frames = y[:, :cur_motion_frames_latent_num] # C T H W
else:
cond_ = cond_image if is_first_clip else cond_frame
latent_motion_frames = vae.encode(cond_.to(device, vae.dtype), device=device, tiled=tiled_vae, pbar=False).to(dtype)[0]
vae.to(offload_device)
#motion_frame_index = cur_motion_frames_latent_num if mode == "infinitetalk" else 1
msk = torch.zeros(4, latent_frame_num, lat_h, lat_w, device=device, dtype=dtype)
msk[:, :1] = 1
y = torch.cat([msk, y]) # 4+C T H W
mm.soft_empty_cache()
else:
y = None
latent_motion_frames = noise[:, :1]
partial_humo_cond_input = partial_humo_cond_neg_input = partial_humo_audio = partial_humo_audio_neg = None
if humo_image_cond is not None:
partial_humo_cond_input = humo_image_cond[:, :latent_frame_num]
partial_humo_cond_neg_input = humo_image_cond_neg[:, :latent_frame_num]
if y is not None:
partial_humo_cond_input[:, :1] = y[:, :1]
if humo_reference_count > 0:
partial_humo_cond_input[:, -humo_reference_count:] = humo_image_cond[:, -humo_reference_count:]
partial_humo_cond_neg_input[:, -humo_reference_count:] = humo_image_cond_neg[:, -humo_reference_count:]
if humo_audio is not None:
if is_first_clip:
audio_embs = None
partial_humo_audio, _ = get_audio_emb_window(humo_audio, frame_num, frame0_idx=audio_start_idx)
#zero_audio_pad = torch.zeros(humo_reference_count, *partial_humo_audio.shape[1:], device=partial_humo_audio.device, dtype=partial_humo_audio.dtype)
partial_humo_audio[-humo_reference_count:] = 0
partial_humo_audio_neg = torch.zeros_like(partial_humo_audio, device=partial_humo_audio.device, dtype=partial_humo_audio.dtype)
if scheduler == "multitalk":
timesteps = list(np.linspace(1000, 1, steps, dtype=np.float32))
timesteps.append(0.)
timesteps = [torch.tensor([t], device=device) for t in timesteps]
timesteps = [timestep_transform(t, shift=shift, num_timesteps=1000) for t in timesteps]
else:
if isinstance(scheduler, dict):
sample_scheduler = copy.deepcopy(scheduler["sample_scheduler"])
timesteps = scheduler["timesteps"]
else:
sample_scheduler, timesteps,_,_ = get_scheduler(scheduler, total_steps, start_step, end_step, shift, device, transformer.dim, denoise_strength, sigmas=sigmas)
timesteps = [torch.tensor([float(t)], device=device) for t in timesteps] + [torch.tensor([0.], device=device)]
# sample videos
latent = noise
# injecting motion frames
if not is_first_clip and mode == "multitalk":
latent_motion_frames = latent_motion_frames.to(latent.dtype).to(device)
motion_add_noise = torch.randn(latent_motion_frames.shape, device=torch.device("cpu"), generator=seed_g).to(device).contiguous()
add_latent = add_noise(latent_motion_frames, motion_add_noise, timesteps[0])
latent[:, :add_latent.shape[1]] = add_latent
if offloaded:
# Load weights
if transformer.patched_linear and gguf_reader is None:
load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args)
elif gguf_reader is not None: #handle GGUF
load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args)
#blockswap init
init_blockswap(transformer, block_swap_args, model)
# Use the appropriate prompt for this section
if len(text_embeds["prompt_embeds"]) > 1:
prompt_index = min(iteration_count, len(text_embeds["prompt_embeds"]) - 1)
positive = [text_embeds["prompt_embeds"][prompt_index]]
log.info(f"Using prompt index: {prompt_index}")
else:
positive = text_embeds["prompt_embeds"]
# uni3c slices
if uni3c_embeds is not None:
vae.to(device)
# Pad original_images if needed
num_frames = original_images.shape[2]
if audio_end_idx > num_frames:
pad_len = audio_end_idx - num_frames
last_frame = original_images[:, :, -1:].repeat(1, 1, pad_len, 1, 1)
padded_images = torch.cat([original_images, last_frame], dim=2)
else:
padded_images = original_images
render_latent = vae.encode(
padded_images[:, :, audio_start_idx:audio_end_idx].to(device, vae.dtype),
device=device, tiled=tiled_vae
).to(dtype)
vae.to(offload_device)
uni3c_data['render_latent'] = render_latent
# unianimate slices
partial_unianim_data = None
if unianim_data is not None:
partial_dwpose = dwpose_data[:, :, latent_start_idx:latent_end_idx]
partial_unianim_data = {
"dwpose": partial_dwpose,
"random_ref": unianim_data["random_ref"],
"strength": unianimate_poses["strength"],
"start_percent": unianimate_poses["start_percent"],
"end_percent": unianimate_poses["end_percent"]
}
# fantasy portrait slices
partial_fantasy_portrait_input = None
if fantasy_portrait_input is not None:
adapter_proj = fantasy_portrait_input["adapter_proj"]
if latent_end_idx > adapter_proj.shape[1]:
pad_len = latent_end_idx - adapter_proj.shape[1]
last_frame = adapter_proj[:, -1:, :, :].repeat(1, pad_len, 1, 1)
padded_proj = torch.cat([adapter_proj, last_frame], dim=1)
else:
padded_proj = adapter_proj
partial_fantasy_portrait_input = fantasy_portrait_input.copy()
partial_fantasy_portrait_input["adapter_proj"] = padded_proj[:, latent_start_idx:latent_end_idx]
mm.soft_empty_cache()
gc.collect()
# sampling loop
sampling_pbar = tqdm(total=len(timesteps)-1, desc=f"Sampling audio indices {audio_start_idx}-{audio_end_idx}", position=0, leave=True)
for i in range(len(timesteps)-1):
timestep = timesteps[i]
latent_model_input = latent.to(device)
if mode == "infinitetalk":
if humo_image_cond is None or not is_first_clip:
latent_model_input[:, :cur_motion_frames_latent_num] = latent_motion_frames
noise_pred, _, self.cache_state = predict_func(
latent_model_input, cfg[min(i, len(timesteps)-1)], positive, text_embeds["negative_prompt_embeds"],
timestep, i, y, clip_embeds, control_latents, None, partial_unianim_data, audio_proj, control_camera_latents, add_cond,
cache_state=self.cache_state, multitalk_audio_embeds=audio_embs, fantasy_portrait_input=partial_fantasy_portrait_input,
humo_image_cond=partial_humo_cond_input, humo_image_cond_neg=partial_humo_cond_neg_input, humo_audio=partial_humo_audio, humo_audio_neg=partial_humo_audio_neg,
uni3c_data = uni3c_data)
if callback is not None:
callback_latent = (latent_model_input.to(device) - noise_pred.to(device) * timestep.to(device) / 1000).detach().permute(1,0,2,3)
callback(step_iteration_count, callback_latent, None, estimated_iterations*(len(timesteps)-1))
del callback_latent
sampling_pbar.update(1)
step_iteration_count += 1
# update latent
if use_tsr:
noise_pred = temporal_score_rescaling(noise_pred, latent, timestep, tsr_k, tsr_sigma)
if scheduler == "multitalk":
noise_pred = -noise_pred
dt = (timesteps[i] - timesteps[i + 1]) / 1000
latent = latent + noise_pred * dt[:, None, None, None]
else:
latent = sample_scheduler.step(noise_pred.unsqueeze(0), timestep, latent.unsqueeze(0).to(noise_pred.device), **scheduler_step_args)[0].squeeze(0)
del noise_pred, latent_model_input, timestep
# differential diffusion inpaint
if masks is not None:
if i < len(timesteps) - 1:
image_latent = add_noise(original_image.to(device), noise.to(device), timesteps[i+1])
mask = masks[i].to(latent)
latent = image_latent * mask + latent * (1-mask)
# injecting motion frames
if not is_first_clip and mode == "multitalk":
latent_motion_frames = latent_motion_frames.to(latent.dtype).to(device)
motion_add_noise = torch.randn(latent_motion_frames.shape, device=torch.device("cpu"), generator=seed_g).to(device).contiguous()
add_latent = add_noise(latent_motion_frames, motion_add_noise, timesteps[i+1])
latent[:, :add_latent.shape[1]] = add_latent
else:
if humo_image_cond is None or not is_first_clip:
latent[:, :cur_motion_frames_latent_num] = latent_motion_frames
del noise, latent_motion_frames
if offload:
offload_transformer(transformer, remove_lora=False)
offloaded = True
if humo_image_cond is not None and humo_reference_count > 0:
latent = latent[:,:-humo_reference_count]
vae.to(device)
videos = vae.decode(latent.unsqueeze(0).to(device, vae.dtype), device=device, tiled=tiled_vae, pbar=False)[0].cpu()
vae.to(offload_device)
sampling_pbar.close()
# optional color correction (less relevant for InfiniteTalk)
if colormatch != "disabled":
videos = videos.permute(1, 2, 3, 0).float().numpy()
from color_matcher import ColorMatcher
cm = ColorMatcher()
cm_result_list = []
for img in videos:
if mode == "multitalk":
cm_result = cm.transfer(src=img, ref=original_images[0].permute(1, 2, 3, 0).squeeze(0).cpu().float().numpy(), method=colormatch)
else:
cm_result = cm.transfer(src=img, ref=cond_image[0].permute(1, 2, 3, 0).squeeze(0).cpu().float().numpy(), method=colormatch)
cm_result_list.append(torch.from_numpy(cm_result).to(vae.dtype))
videos = torch.stack(cm_result_list, dim=0).permute(3, 0, 1, 2)
# optionally save generated samples to disk
if output_path:
video_np = videos.clamp(-1.0, 1.0).add(1.0).div(2.0).mul(255).cpu().float().numpy().transpose(1, 2, 3, 0).astype('uint8')
num_frames_to_save = video_np.shape[0] if is_first_clip else video_np.shape[0] - cur_motion_frames_num
log.info(f"Saving {num_frames_to_save} generated frames to {output_path}")
start_idx = 0 if is_first_clip else cur_motion_frames_num
for i in range(start_idx, video_np.shape[0]):
im = Image.fromarray(video_np[i])
im.save(os.path.join(output_path, f"frame_{img_counter:05d}.png"))
img_counter += 1
else:
gen_video_list.append(videos if is_first_clip else videos[:, cur_motion_frames_num:])
current_condframe_index += 1
iteration_count += 1
# decide whether is done
if arrive_last_frame:
break
# update next condition frames
is_first_clip = False
cur_motion_frames_num = motion_frame
cond_ = videos[:, -cur_motion_frames_num:].unsqueeze(0)
if mode == "infinitetalk":
cond_frame = cond_
else:
cond_image = cond_
del videos, latent
# Repeat audio emb
if multitalk_embeds is not None:
audio_start_idx += (frame_num - cur_motion_frames_num - humo_reference_count)
audio_end_idx = audio_start_idx + clip_length
if audio_end_idx >= len(audio_embedding[0]):
arrive_last_frame = True
miss_lengths = []
source_frames = []
for human_inx in range(human_num):
source_frame = len(audio_embedding[human_inx])
source_frames.append(source_frame)
if audio_end_idx >= len(audio_embedding[human_inx]):
log.warning(f"Audio embedding for subject {human_inx} not long enough: {len(audio_embedding[human_inx])}, need {audio_end_idx}, padding...")
miss_length = audio_end_idx - len(audio_embedding[human_inx]) + 3
log.warning(f"Padding length: {miss_length}")
if encoded_silence is not None:
add_audio_emb = encoded_silence[-1*miss_length:]
else:
add_audio_emb = torch.flip(audio_embedding[human_inx][-1*miss_length:], dims=[0])
audio_embedding[human_inx] = torch.cat([audio_embedding[human_inx], add_audio_emb.to(device, dtype)], dim=0)
miss_lengths.append(miss_length)
else:
miss_lengths.append(0)
if mode == "infinitetalk" and current_condframe_index >= original_images.shape[2]:
last_frame = original_images[:, :, -1:, :, :]
miss_length = 1
original_images = torch.cat([original_images, last_frame.repeat(1, 1, miss_length, 1, 1)], dim=2)
if not output_path:
gen_video_samples = torch.cat(gen_video_list, dim=1)
else:
gen_video_samples = torch.zeros(3, 1, 64, 64) # dummy output
if force_offload:
if not model["auto_cpu_offload"]:
offload_transformer(transformer)
try:
print_memory(device)
torch.cuda.reset_peak_memory_stats(device)
except:
pass
return {"video": gen_video_samples.permute(1, 2, 3, 0), "output_path": output_path},

nodes.py (111 changed lines)

@@ -888,6 +888,83 @@ class WanVideoAddMTVMotion:
updated["mtv_crafter_motion"] = new_entry
return (updated,)
class WanVideoAddStoryMemLatents:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"vae": ("WANVAE",),
"embeds": ("WANVIDIMAGE_EMBEDS",),
"memory_images": ("IMAGE",),
}
}
RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
RETURN_NAMES = ("image_embeds",)
FUNCTION = "add"
CATEGORY = "WanVideoWrapper"
def add(self, vae, embeds, memory_images):
updated = dict(embeds)
story_mem_latents, = WanVideoEncodeLatentBatch().encode(vae, memory_images)
updated["story_mem_latents"] = story_mem_latents["samples"].squeeze(2).permute(1, 0, 2, 3) # [C, T, H, W]
return (updated,)
class WanVideoSVIProEmbeds:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"anchor_samples": ("LATENT", {"tooltip": "Initial start image encoded"}),
"num_frames": ("INT", {"default": 81, "min": 1, "max": 10000, "step": 4, "tooltip": "Number of frames to encode"}),
},
"optional": {
"prev_samples": ("LATENT", {"tooltip": "Last latent from previous generation"}),
"motion_latent_count": ("INT", {"default": 1, "min": 0, "max": 100, "step": 1, "tooltip": "Number of latents used to continue"}),
}
}
RETURN_TYPES = ("WANVIDIMAGE_EMBEDS",)
RETURN_NAMES = ("image_embeds",)
FUNCTION = "add"
CATEGORY = "WanVideoWrapper"
def add(self, anchor_samples, num_frames, prev_samples=None, motion_latent_count=1):
anchor_latent = anchor_samples["samples"][0].clone()
C, T, H, W = anchor_latent.shape
total_latents = (num_frames - 1) // 4 + 1
device = anchor_latent.device
dtype = anchor_latent.dtype
if prev_samples is None or motion_latent_count == 0:
padding_size = total_latents - anchor_latent.shape[1]
padding = torch.zeros(C, padding_size, H, W, dtype=dtype, device=device)
y = torch.concat([anchor_latent, padding], dim=1)
else:
prev_latent = prev_samples["samples"][0].clone()
motion_latent = prev_latent[:, -motion_latent_count:]
padding_size = total_latents - anchor_latent.shape[1] - motion_latent.shape[1]
padding = torch.zeros(C, padding_size, H, W, dtype=dtype, device=device)
y = torch.concat([anchor_latent, motion_latent, padding], dim=1)
msk = torch.ones(1, num_frames, H, W, device=device, dtype=dtype)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, H, W)
msk = msk.transpose(1, 2)[0]
image_embeds = {
"image_embeds": y,
"num_frames": num_frames,
"lat_h": H,
"lat_w": W,
"mask": msk
}
return (image_embeds,)
#region I2V encode
class WanVideoImageToVideoEncode:
@classmethod
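
A small worked sketch of the conditioning-mask math in WanVideoSVIProEmbeds above, with assumed latent sizes: frame 0 is the conditioning frame, all later frames are generated, and the per-frame mask is folded into the 4x temporal VAE stride:

import torch

num_frames, H, W = 81, 60, 104                 # H, W are latent sizes (assumed)
total_latents = (num_frames - 1) // 4 + 1      # 21
msk = torch.ones(1, num_frames, H, W)
msk[:, 1:] = 0                                 # only frame 0 is conditioning
msk = torch.cat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, H, W).transpose(1, 2)[0]
print(total_latents, msk.shape)                # 21 torch.Size([4, 21, 60, 104])
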
@@ -1826,33 +1903,7 @@ class WanVideoContextOptions:
}
return (context_options,)
class WanVideoFlowEdit:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"source_embeds": ("WANVIDEOTEXTEMBEDS", ),
"skip_steps": ("INT", {"default": 4, "min": 0}),
"drift_steps": ("INT", {"default": 0, "min": 0}),
"drift_flow_shift": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 30.0, "step": 0.01}),
"source_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"drift_cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
},
"optional": {
"source_image_embeds": ("WANVIDIMAGE_EMBEDS", ),
}
}
RETURN_TYPES = ("FLOWEDITARGS", )
RETURN_NAMES = ("flowedit_args",)
FUNCTION = "process"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Flowedit options for WanVideo"
def process(self, **kwargs):
return (kwargs,)
class WanVideoLoopArgs:
@classmethod
def INPUT_TYPES(s):
@@ -2122,7 +2173,7 @@ class WanVideoEncodeLatentBatch:
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Encodes a batch of images individually to create a latent video batch where each video is a single frame, useful for I2V init purposes, for example as multiple context window inits"
def encode(self, vae, images, enable_vae_tiling, tile_x, tile_y, tile_stride_x, tile_stride_y, latent_strength=1.0):
def encode(self, vae, images, enable_vae_tiling=False, tile_x=272, tile_y=272, tile_stride_x=144, tile_stride_y=128, latent_strength=1.0):
vae.to(device)
images = images.clone()
@@ -2228,7 +2279,6 @@ NODE_CLASS_MAPPINGS = {
"WanVideoEnhanceAVideo": WanVideoEnhanceAVideo,
"WanVideoContextOptions": WanVideoContextOptions,
"WanVideoTextEmbedBridge": WanVideoTextEmbedBridge,
"WanVideoFlowEdit": WanVideoFlowEdit,
"WanVideoControlEmbeds": WanVideoControlEmbeds,
"WanVideoSLG": WanVideoSLG,
"WanVideoLoopArgs": WanVideoLoopArgs,
@@ -2255,6 +2305,8 @@ NODE_CLASS_MAPPINGS = {
"TextImageEncodeQwenVL": TextImageEncodeQwenVL,
"WanVideoUniLumosEmbeds": WanVideoUniLumosEmbeds,
"WanVideoAddTTMLatents": WanVideoAddTTMLatents,
"WanVideoAddStoryMemLatents": WanVideoAddStoryMemLatents,
"WanVideoSVIProEmbeds": WanVideoSVIProEmbeds,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -2270,7 +2322,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoEnhanceAVideo": "WanVideo Enhance-A-Video",
"WanVideoContextOptions": "WanVideo Context Options",
"WanVideoTextEmbedBridge": "WanVideo TextEmbed Bridge",
"WanVideoFlowEdit": "WanVideo FlowEdit",
"WanVideoControlEmbeds": "WanVideo Control Embeds",
"WanVideoSLG": "WanVideo SLG",
"WanVideoLoopArgs": "WanVideo Loop Args",
@@ -2296,4 +2347,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoAddBindweaveEmbeds": "WanVideo Add Bindweave Embeds",
"WanVideoUniLumosEmbeds": "WanVideo UniLumos Embeds",
"WanVideoAddTTMLatents": "WanVideo Add TTMLatents",
"WanVideoAddStoryMemLatents": "WanVideo Add StoryMem Latents",
"WanVideoSVIProEmbeds": "WanVideo SVIPro Embeds",
}

@@ -36,6 +36,9 @@ try:
except:
PromptServer = None
attention_modes = ["sdpa", "flash_attn_2", "flash_attn_3", "sageattn", "sageattn_3", "radial_sage_attention", "sageattn_compiled",
"sageattn_ultravico", "comfy"]
#from city96's gguf nodes
def update_folder_names_and_paths(key, targets=[]):
# check for existing key
@@ -178,7 +181,7 @@ def standardize_lora_key_format(lora_sd):
new_key += f".{component}"
# Handle weight type - this is the critical fix
# Handle weight type
if weight_type:
if weight_type == 'alpha':
new_key += '.alpha'
@@ -209,12 +212,12 @@ def standardize_lora_key_format(lora_sd):
new_key = new_key.replace('time_embedding', 'time.embedding')
new_key = new_key.replace('time_projection', 'time.projection')
# Replace remaining underscores with dots, carefully
# Replace remaining underscores with dots
parts = new_key.split('.')
final_parts = []
for part in parts:
if part in ['img_emb', 'self_attn', 'cross_attn']:
final_parts.append(part) # Keep these intact
final_parts.append(part)
else:
final_parts.append(part.replace('_', '.'))
new_key = '.'.join(final_parts)
@@ -274,6 +277,20 @@ def standardize_lora_key_format(lora_sd):
new_sd[k] = v
return new_sd
def compensate_rs_lora_format(lora_sd):
rank = lora_sd["base_model.model.blocks.0.cross_attn.k.lora_A.weight"].shape[0]
alpha = torch.tensor(rank * rank // rank ** 0.5)
log.info(f"Detected rank stabilized peft lora format with rank {rank}, setting alpha to {alpha} to compensate.")
new_sd = {}
for k, v in lora_sd.items():
if k.endswith(".lora_A.weight"):
new_sd[k] = v
new_k = k.replace(".lora_A.weight", ".alpha")
new_sd[new_k] = alpha
else:
new_sd[k] = v
return new_sd
class WanVideoBlockSwap:
@classmethod
def INPUT_TYPES(s):
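
A worked example of the alpha compensation in compensate_rs_lora_format above, assuming the usual peft conventions (standard LoRA scales updates by alpha / rank, rank-stabilized LoRA trains with alpha / sqrt(rank), so an equivalent alpha for the standard path is rank^2 / sqrt(rank)):

import torch

rank = 32
alpha = torch.tensor(rank * rank // rank ** 0.5)
# effective scale through the standard alpha / rank path matches sqrt(rank)
print(alpha, float(alpha) / rank, rank ** 0.5)   # tensor(181.) 5.65625 5.6568...
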
@@ -364,7 +381,7 @@ class WanVideoLoraSelect:
"required": {
"lora": (folder_paths.get_filename_list("loras"),
{"tooltip": "LORA models are expected to be in ComfyUI/models/loras with .safetensors extension"}),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
"strength": ("FLOAT", {"default": 1.0, "min": -1000.0, "max": 1000.0, "step": 0.0001, "tooltip": "LORA strength, set to 0.0 to unmerge the LORA"}),
},
"optional": {
"prev_lora":("WANVIDLORA", {"default": None, "tooltip": "For loading multiple LoRAs"}),
@@ -756,6 +773,8 @@ class WanVideoSetLoRAs:
lora_sd = load_torch_file(lora_path, safe_load=True)
if "dwpose_embedding.0.weight" in lora_sd: #unianimate
raise NotImplementedError("Unianimate LoRA patching is not implemented in this node.")
if "base_model.model.blocks.0.cross_attn.k.lora_A.weight" in lora_sd: # assume rs_lora
lora_sd = compensate_rs_lora_format(lora_sd)
lora_sd = standardize_lora_key_format(lora_sd)
if l["blocks"]:
@@ -967,7 +986,8 @@ def add_lora_weights(patcher, lora, base_dtype, merge_loras=False):
from .unianimate.nodes import update_transformer
log.info("Unianimate LoRA detected, patching model...")
patcher.model.diffusion_model, unianimate_sd = update_transformer(patcher.model.diffusion_model, lora_sd)
if "base_model.model.blocks.0.cross_attn.k.lora_A.weight" in lora_sd: # assume rs_lora
lora_sd = compensate_rs_lora_format(lora_sd)
lora_sd = standardize_lora_key_format(lora_sd)
if l["blocks"]:
@@ -989,6 +1009,43 @@ def add_lora_weights(patcher, lora, base_dtype, merge_loras=False):
del lora_sd
return patcher, control_lora, unianimate_sd
class WanVideoSetAttentionModeOverride:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("WANVIDEOMODEL", ),
"attention_mode": (attention_modes, {"default": "sdpa"}),
"start_step": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Step to start applying the attention mode override"}),
"end_step": ("INT", {"default": 10000, "min": 1, "max": 10000, "step": 1, "tooltip": "Step to end applying the attention mode override"}),
"verbose": ("BOOLEAN", {"default": False, "tooltip": "Print verbose info about attention mode override during generation"}),
},
"optional": {
"blocks":("INT", {"forceInput": True} ),
}
}
RETURN_TYPES = ("WANVIDEOMODEL",)
RETURN_NAMES = ("model", )
FUNCTION = "getmodelpath"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Override the attention mode for the model for specific step and/or block range"
def getmodelpath(self, model, attention_mode, start_step, end_step, verbose, blocks=None):
model_clone = model.clone()
attention_mode_override = {
"mode": attention_mode,
"start_step": start_step,
"end_step": end_step,
"verbose": verbose,
}
if blocks is not None:
attention_mode_override["blocks"] = blocks
model_clone.model_options['transformer_options']["attention_mode_override"] = attention_mode_override
return (model_clone,)
#region Model loading
class WanVideoModelLoader:
@classmethod
@@ -1003,17 +1060,7 @@ class WanVideoModelLoader:
"load_device": (["main_device", "offload_device"], {"default": "offload_device", "tooltip": "Initial device to load the model to, NOT recommended with the larger models unless you have 48GB+ VRAM"}),
},
"optional": {
"attention_mode": ([
"sdpa",
"flash_attn_2",
"flash_attn_3",
"sageattn",
"sageattn_3",
"radial_sage_attention",
"sageattn_compiled",
"sageattn_ultravico",
"comfy"
], {"default": "sdpa"}),
"attention_mode": (attention_modes, {"default": "sdpa"}),
"compile_args": ("WANCOMPILEARGS", ),
"block_swap_args": ("BLOCKSWAPARGS", ),
"lora": ("WANVIDLORA", {"default": None}),
@@ -1218,9 +1265,7 @@ class WanVideoModelLoader:
lynx_ip_layers = "lite"
model_type = "t2v"
if "audio_injector.injector.0.k.weight" in sd:
model_type = "s2v"
elif not "text_embedding.0.weight" in sd:
if not "text_embedding.0.weight" in sd:
model_type = "no_cross_attn" #minimaxremover
elif "model_type.Wan2_1-FLF2V-14B-720P" in sd or "img_emb.emb_pos" in sd or "flf2v" in model.lower():
model_type = "fl2v"
@@ -1230,6 +1275,8 @@ class WanVideoModelLoader:
model_type = "t2v"
elif "control_adapter.conv.weight" in sd:
model_type = "t2v"
if "audio_injector.injector.0.k.weight" in sd:
model_type = "s2v"
out_dim = 16
if dim == 5120: #14B
@@ -2044,6 +2091,7 @@ NODE_CLASS_MAPPINGS = {
"WanVideoTorchCompileSettings": WanVideoTorchCompileSettings,
"LoadWanVideoT5TextEncoder": LoadWanVideoT5TextEncoder,
"LoadWanVideoClipTextEncoder": LoadWanVideoClipTextEncoder,
"WanVideoSetAttentionModeOverride": WanVideoSetAttentionModeOverride,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@@ -2062,4 +2110,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoTorchCompileSettings": "WanVideo Torch Compile Settings",
"LoadWanVideoT5TextEncoder": "WanVideo T5 Text Encoder Loader",
"LoadWanVideoClipTextEncoder": "WanVideo CLIP Text Encoder Loader",
"WanVideoSetAttentionModeOverride": "WanVideo Set Attention Mode Override",
}

File diff suppressed because it is too large

@@ -38,7 +38,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, current_flag,
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
window_th = 1560 * 21 / 2
window_th = frame_tokens * window_width / 2
dist2 = tl.abs(m - n).to(tl.int32)
dist_mask = dist2 <= window_th
@@ -46,7 +46,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, current_flag,
qk = tl.where(dist_mask | negative_mask, qk, qk*multi_factor)
window3 = (m <= frame_tokens) & (n > 21*frame_tokens)
window3 = (m <= frame_tokens) & (n > window_width*frame_tokens)
qk = tl.where(window3, -1e4, qk)
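
For context on the change above: the old kernel hard-coded the half-window as 1560 * 21 / 2 tokens, while passing frame_tokens and window_width reproduces that value for the original layout and scales with other resolutions (numbers below are illustrative only):

frame_tokens, window_width = 1560, 21
window_th = frame_tokens * window_width / 2
print(window_th)   # 16380.0, identical to the previous constant 1560 * 21 / 2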

@@ -4,17 +4,99 @@ import logging
import math
from tqdm import tqdm
from pathlib import Path
import os
import gc
import types, collections
from comfy.utils import ProgressBar, copy_to_param, set_attr_param
from comfy.model_patcher import get_key_weight, string_to_seed
from comfy.lora import calculate_weight
from comfy.model_management import cast_to_device
from comfy.float import stochastic_rounding
from .custom_linear import remove_lora_from_module
import folder_paths
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
import comfy.model_management as mm
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
try:
from .gguf.gguf import GGUFParameter
except:
pass
class MetaParameter(torch.nn.Parameter):
def __new__(cls, dtype, quant_type=None):
data = torch.empty(0, dtype=dtype)
self = torch.nn.Parameter(data, requires_grad=False)
self.quant_type = quant_type
return self
def offload_transformer(transformer, remove_lora=True):
transformer.teacache_state.clear_all()
transformer.magcache_state.clear_all()
transformer.easycache_state.clear_all()
if transformer.patched_linear:
for name, param in transformer.named_parameters():
if "loras" in name or "controlnet" in name:
continue
module = transformer
subnames = name.split('.')
for subname in subnames[:-1]:
module = getattr(module, subname)
attr_name = subnames[-1]
if param.data.is_floating_point():
meta_param = torch.nn.Parameter(torch.empty_like(param.data, device='meta'), requires_grad=False)
setattr(module, attr_name, meta_param)
elif isinstance(param.data, GGUFParameter):
quant_type = getattr(param, 'quant_type', None)
setattr(module, attr_name, MetaParameter(param.data.dtype, quant_type))
else:
pass
if remove_lora:
remove_lora_from_module(transformer)
else:
transformer.to(offload_device)
for block in transformer.blocks:
block.kv_cache = None
if transformer.audio_model is not None and hasattr(block, 'audio_block'):
block.audio_block = None
mm.soft_empty_cache()
gc.collect()
def init_blockswap(transformer, block_swap_args, model):
if not transformer.patched_linear:
if block_swap_args is not None:
for name, param in transformer.named_parameters():
if "block" not in name or "control_adapter" in name or "face" in name:
param.data = param.data.to(device)
elif block_swap_args["offload_txt_emb"] and "txt_emb" in name:
param.data = param.data.to(offload_device)
elif block_swap_args["offload_img_emb"] and "img_emb" in name:
param.data = param.data.to(offload_device)
transformer.block_swap(
block_swap_args["blocks_to_swap"] - 1 ,
block_swap_args["offload_txt_emb"],
block_swap_args["offload_img_emb"],
vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None),
)
elif model["auto_cpu_offload"]:
for module in transformer.modules():
if hasattr(module, "offload"):
module.offload()
if hasattr(module, "onload"):
module.onload()
for block in transformer.blocks:
block.modulation = torch.nn.Parameter(block.modulation.to(device))
transformer.head.modulation = torch.nn.Parameter(transformer.head.modulation.to(device))
else:
transformer.to(device)
def check_device_same(first_device, second_device):
if first_device.type != second_device.type:
return False
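
A minimal sketch of the meta-device trick used by offload_transformer above: a floating-point parameter is replaced by an empty tensor on the 'meta' device, which keeps the module structure (names, shapes, dtypes) while releasing the memory, so the real weights must be re-loaded before the next sampling window:

import torch

layer = torch.nn.Linear(4096, 4096)
layer.weight = torch.nn.Parameter(torch.empty_like(layer.weight, device="meta"), requires_grad=False)
print(layer.weight.shape, layer.weight.device)   # torch.Size([4096, 4096]) meta
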
@@ -140,7 +222,7 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False, back
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
if device_to is not None:
temp_weight = cast_to_device(weight, device_to, torch.float32, copy=True)
temp_weight = mm.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
if convert_func is not None:

View File

@@ -80,9 +80,9 @@ except:
try:
from ...ultravico.sageattn.core import sage_attention as sageattn_ultravico
@torch.library.custom_op("wanvideo::sageattn_ultravico", mutates_args=())
def sageattn_func_ultravico(qkv: List[torch.Tensor], attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, multi_factor: float = 0.9
def sageattn_func_ultravico(qkv: List[torch.Tensor], attn_mask: torch.Tensor | None = None, dropout_p: float = 0.0, is_causal: bool = False, multi_factor: float = 0.9, frame_tokens: int = 1536
) -> torch.Tensor:
return sageattn_ultravico(qkv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, multi_factor=multi_factor)
return sageattn_ultravico(qkv, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, multi_factor=multi_factor, frame_tokens=frame_tokens)
@sageattn_func_ultravico.register_fake
def _(qkv, attn_mask=None, dropout_p=0.0, is_causal=False, multi_factor=0.9):
@@ -94,7 +94,7 @@ except:
def attention(q, k, v, q_lens=None, k_lens=None, max_seqlen_q=None, max_seqlen_k=None, dropout_p=0.,
softmax_scale=None, q_scale=None, causal=False, window_size=(-1, -1), deterministic=False, dtype=torch.bfloat16,
attention_mode='sdpa', attn_mask=None, multi_factor=0.9, heads=128):
attention_mode='sdpa', attn_mask=None, multi_factor=0.9, frame_tokens=1536, heads=128):
if "flash" in attention_mode:
return flash_attention(q, k, v, q_lens=q_lens, k_lens=k_lens, dropout_p=dropout_p, softmax_scale=softmax_scale,
q_scale=q_scale, causal=causal, window_size=window_size, deterministic=deterministic, dtype=dtype, version=2 if attention_mode == 'flash_attn_2' else 3,
@@ -108,7 +108,7 @@ def attention(q, k, v, q_lens=None, k_lens=None, max_seqlen_q=None, max_seqlen_k
elif attention_mode == 'sageattn':
return sageattn_func(q, k, v, tensor_layout="NHD").contiguous()
elif attention_mode == 'sageattn_ultravico':
return sageattn_func_ultravico([q, k, v], multi_factor=multi_factor).contiguous()
return sageattn_func_ultravico([q, k, v], multi_factor=multi_factor, frame_tokens=frame_tokens).contiguous()
elif attention_mode == 'comfy':
return optimized_attention(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), heads=heads, skip_reshape=True)
else: # sdpa

@@ -467,7 +467,7 @@ class WanSelfAttention(nn.Module):
v = (self.v(x) + self.v_loras(x)).view(b, s, n, d)
return q, k, v
def forward(self, q, k, v, seq_lens, lynx_ref_feature=None, lynx_ref_scale=1.0, attention_mode_override=None, onetoall_ref=None, onetoall_ref_scale=1.0):
def forward(self, q, k, v, seq_lens, lynx_ref_feature=None, lynx_ref_scale=1.0, attention_mode_override=None, onetoall_ref=None, onetoall_ref_scale=1.0, frame_tokens=1536):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -482,7 +482,7 @@ class WanSelfAttention(nn.Module):
if self.ref_adapter is not None and lynx_ref_feature is not None:
ref_x = self.ref_adapter(self, q, lynx_ref_feature)
x = attention(q, k, v, k_lens=seq_lens, attention_mode=attention_mode, heads=self.num_heads)
x = attention(q, k, v, k_lens=seq_lens, attention_mode=attention_mode, heads=self.num_heads, frame_tokens=frame_tokens)
if self.ref_adapter is not None and lynx_ref_feature is not None:
x = x.add(ref_x, alpha=lynx_ref_scale)
@@ -497,7 +497,7 @@ class WanSelfAttention(nn.Module):
attention_mode = self.attention_mode
if attention_mode_override is not None:
attention_mode = attention_mode_override
# Concatenate main and IP keys/values for main attention
full_k = torch.cat([k, k_ip], dim=1)
full_v = torch.cat([v, v_ip], dim=1)
@@ -1006,6 +1006,7 @@ class WanAttentionBlock(nn.Module):
longcat_num_cond_latents=0, longcat_avatar_options=None, #longcat image cond amount
x_onetoall_ref=None, onetoall_freqs=None, onetoall_ref=None, onetoall_ref_scale=1.0, #one-to-all
e_tr=None, tr_num=0, tr_start=0, #token replacement
attention_mode_override=None, frame_tokens=None,
):
r"""
Args:
@@ -1150,6 +1151,10 @@ class WanAttentionBlock(nn.Module):
if enhance_enabled:
feta_scores = get_feta_scores(q, k)
if self.attention_mode == "sageattn_3" and attention_mode_override is None:
if current_step != 0 and not last_step:
attention_mode_override = "sageattn"
#self-attention
split_attn = (context is not None
and (context.shape[0] > 1 or (clip_embed is not None and clip_embed.shape[0] > 1))
@@ -1161,19 +1166,14 @@ class WanAttentionBlock(nn.Module):
y = self.self_attn.forward_split(q, k, v, seq_lens, grid_sizes, seq_chunks)
elif ref_target_masks is not None: #multi/infinite talk
y, x_ref_attn_map = self.self_attn.forward_multitalk(q, k, v, seq_lens, grid_sizes, ref_target_masks)
elif self.attention_mode == "radial_sage_attention":
elif self.attention_mode == "radial_sage_attention" or attention_mode_override is not None and attention_mode_override == "radial_sage_attention":
if self.dense_block or self.dense_timesteps is not None and current_step < self.dense_timesteps:
if self.dense_attention_mode == "sparse_sage_attn":
y = self.self_attn.forward_radial(q, k, v, dense_step=True)
else:
y = self.self_attn.forward(q, k, v, seq_lens)
y = self.self_attn.forward(q, k, v, seq_lens, attention_mode_override=attention_mode_override)
else:
y = self.self_attn.forward_radial(q, k, v, dense_step=False)
elif self.attention_mode == "sageattn_3":
if current_step != 0 and not last_step:
y = self.self_attn.forward(q, k, v, seq_lens, attention_mode_override="sageattn_3")
else:
y = self.self_attn.forward(q, k, v, seq_lens, attention_mode_override="sageattn")
elif x_ip is not None and self.kv_cache is None: #stand-in
# First pass: cache IP keys/values and compute attention
self.kv_cache = {"k_ip": k_ip.detach(), "v_ip": v_ip.detach()}
@@ -1184,18 +1184,18 @@ class WanAttentionBlock(nn.Module):
v_ip = self.kv_cache["v_ip"]
full_k = torch.cat([k, k_ip], dim=1)
full_v = torch.cat([v, v_ip], dim=1)
y = self.self_attn.forward(q, full_k, full_v, seq_lens)
y = self.self_attn.forward(q, full_k, full_v, seq_lens, attention_mode_override=attention_mode_override)
elif is_longcat and longcat_num_cond_latents > 0:
if longcat_num_cond_latents == 1:
num_cond_latents_thw = longcat_num_cond_latents * (N // num_latent_frames)
# process the noise tokens
x_noise = self.self_attn.forward(q[:, num_cond_latents_thw:].contiguous(), k, v, seq_lens)
x_noise = self.self_attn.forward(q[:, num_cond_latents_thw:].contiguous(), k, v, seq_lens, attention_mode_override=attention_mode_override)
# process the condition tokens
x_cond = self.self_attn.forward(
q[:, :num_cond_latents_thw].contiguous(),
k[:, :num_cond_latents_thw].contiguous(),
v[:, :num_cond_latents_thw].contiguous(),
seq_lens)
seq_lens, attention_mode_override=attention_mode_override)
# merge x_cond and x_noise
y = torch.cat([x_cond, x_noise], dim=1).contiguous()
elif longcat_num_cond_latents > 1: # video continuation
@@ -1237,13 +1237,14 @@ class WanAttentionBlock(nn.Module):
q_cond = q[:, num_ref_latents_thw:num_cond_latents_thw].contiguous()
k_cond = k[:, num_ref_latents_thw:num_cond_latents_thw].contiguous()
v_cond = v[:, num_ref_latents_thw:num_cond_latents_thw].contiguous()
x_ref = self.self_attn.forward(q_ref, k_ref, v_ref, seq_lens)
x_cond = self.self_attn.forward(q_cond, k_cond, v_cond, seq_lens)
x_ref = self.self_attn.forward(q_ref, k_ref, v_ref, seq_lens, attention_mode_override=attention_mode_override)
x_cond = self.self_attn.forward(q_cond, k_cond, v_cond, seq_lens, attention_mode_override=attention_mode_override)
# merge x_cond and x_noise
y = torch.cat([x_ref, x_cond, x_noise], dim=1).contiguous()
else:
y = self.self_attn.forward(q, k, v, seq_lens, lynx_ref_feature=lynx_ref_feature, lynx_ref_scale=lynx_ref_scale, onetoall_ref=onetoall_ref, onetoall_ref_scale=onetoall_ref_scale)
y = self.self_attn.forward(q, k, v, seq_lens, lynx_ref_feature=lynx_ref_feature, lynx_ref_scale=lynx_ref_scale,
onetoall_ref=onetoall_ref, onetoall_ref_scale=onetoall_ref_scale, attention_mode_override=attention_mode_override, frame_tokens=frame_tokens)
del q, k, v
@@ -2187,7 +2188,7 @@ class WanModel(torch.nn.Module):
def rope_encode_comfy(self, t, h, w, freq_offset=0, t_start=0, ref_frame_shape=None, pose_frame_shape=None,
steps_t=None, steps_h=None, steps_w=None, ntk_alphas=[1,1,1], device=None, dtype=None,
ref_frame_index=10, longcat_num_ref_latents=None):
ref_frame_index=10, longcat_num_ref_latents=0):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -2280,6 +2281,7 @@ class WanModel(torch.nn.Module):
self, x, t, context, seq_len,
is_uncond=False,
current_step_percentage=0.0, current_step=0, last_step=0, total_steps=50,
attention_mode_override=None,
clip_fea=None, y=None,
device=torch.device('cuda'),
freqs=None,
@@ -2664,7 +2666,7 @@ class WanModel(torch.nn.Module):
device=x.device,
dtype=x.dtype
)
log.info("Generated new RoPE frequencies")
tqdm.write("Generated new RoPE frequencies")
if s2v_ref_latent is not None:
freqs_ref = self.rope_encode_comfy(
@@ -3093,6 +3095,7 @@ class WanModel(torch.nn.Module):
camera_embed=camera_embed,
audio_proj=audio_proj,
num_latent_frames = F,
frame_tokens=x.shape[1] // F,
original_seq_len=self.original_seq_len,
enhance_enabled=enhance_enabled,
audio_scale=audio_scale,
@@ -3179,8 +3182,21 @@ class WanModel(torch.nn.Module):
if lynx_ref_buffer is None and lynx_ref_feature_extractor:
lynx_ref_buffer = {}
attn_override_blocks = attention_mode = None
attention_mode_override_active = False
if attention_mode_override is not None:
attn_override_blocks = attention_mode_override.get("blocks", range(len(self.blocks)))
if attention_mode_override["start_step"] <= current_step < attention_mode_override["end_step"]:
attention_mode_override_active = True
if attention_mode_override["verbose"]:
tqdm.write(f"Applying attention mode override: {attention_mode_override['mode']} at step {current_step} on blocks: {attn_override_blocks if attn_override_blocks is not None else 'all'}")
for b, block in enumerate(self.blocks):
mm.throw_exception_if_processing_interrupted()
if attention_mode_override_active and b in attn_override_blocks:
attention_mode = attention_mode_override['mode']
else:
attention_mode = None
block_idx = f"{b:02d}"
if lynx_ref_buffer is not None and not lynx_ref_feature_extractor:
lynx_ref_feature = lynx_ref_buffer.get(block_idx, None)
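
A minimal sketch (hypothetical values) of how the override from WanVideoSetAttentionModeOverride is resolved per step and per block in the loop above:

attention_mode_override = {"mode": "sageattn", "start_step": 0, "end_step": 4,
                           "blocks": [0, 1, 2], "verbose": False}

def resolve_attention_mode(default_mode, override, current_step, block_idx, num_blocks=40):
    if override is None:
        return default_mode
    blocks = override.get("blocks", range(num_blocks))   # no block list means all blocks
    active = override["start_step"] <= current_step < override["end_step"]
    return override["mode"] if (active and block_idx in blocks) else default_mode

print(resolve_attention_mode("sdpa", attention_mode_override, 2, 1))   # sageattn
print(resolve_attention_mode("sdpa", attention_mode_override, 5, 1))   # sdpa
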
@@ -3224,7 +3240,7 @@ class WanModel(torch.nn.Module):
x_onetoall_ref = onetoall_ref_block_samples[b // interval_ref]
# ---run block----#
x, x_ip, lynx_ref_feature, x_ovi = block(x, x_ip=x_ip, lynx_ref_feature=lynx_ref_feature, x_ovi=x_ovi, x_onetoall_ref=x_onetoall_ref, onetoall_freqs=onetoall_freqs, **kwargs)
x, x_ip, lynx_ref_feature, x_ovi = block(x, x_ip=x_ip, lynx_ref_feature=lynx_ref_feature, x_ovi=x_ovi, x_onetoall_ref=x_onetoall_ref, onetoall_freqs=onetoall_freqs, attention_mode_override=attention_mode, **kwargs)
# ---post block----#
# dual controlnet

@@ -42,7 +42,7 @@ def _apply_custom_sigmas(sample_scheduler, sigmas, device):
sample_scheduler.timesteps = (sample_scheduler.sigmas[:-1] * 1000).to(torch.int64).to(device)
sample_scheduler.num_inference_steps = len(sample_scheduler.timesteps)
def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transformer_dim=5120, flowedit_args=None, denoise_strength=1.0, sigmas=None, log_timesteps=False, enhance_hf=False, **kwargs):
def get_scheduler(scheduler, steps, start_step, end_step, shift, device, transformer_dim=5120, denoise_strength=1.0, sigmas=None, log_timesteps=False, enhance_hf=False, **kwargs):
timesteps = None
if sigmas is not None:
steps = len(sigmas) - 1
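
A short sketch of the custom-sigmas convention used around get_scheduler and _apply_custom_sigmas above: N + 1 sigma values define N sampling steps, and the integer timesteps are the first N sigmas scaled by 1000 (the values below are just an example):

import torch

sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])
steps = len(sigmas) - 1                          # 4
timesteps = (sigmas[:-1] * 1000).to(torch.int64)
print(steps, timesteps.tolist())                 # 4 [1000, 750, 500, 250]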