
Allow WanMove to work with context windows

Author: kijai
Date:   2025-12-15 01:22:21 +02:00
Parent: 3611341339
Commit: 3aae54f220


@@ -16,6 +16,7 @@ from .utils import(log, print_memory, apply_lora, fourier_filter, optimized_scal
 from .cache_methods.cache_methods import cache_report
 from .nodes_model_loading import load_weights
 from .enhance_a_video.globals import set_enhance_weight, set_num_frames
+from .WanMove.trajectory import replace_feature
 from contextlib import nullcontext
 from comfy import model_management as mm
@@ -301,15 +302,6 @@ class WanVideoSampler:
         #I2V
         image_cond = image_embeds.get("image_embeds", None)
         if image_cond is not None:
-            # WanMove
-            wanmove_embeds = image_embeds.get("wanmove_embeds", None)
-            if wanmove_embeds is not None:
-                from .WanMove.trajectory import replace_feature
-                track_pos = wanmove_embeds["track_pos"]
-                if any(not math.isclose(c, 1.0) for c in cfg):
-                    image_cond_neg = torch.cat([image_embeds["mask"], image_cond])
-                image_cond = replace_feature(image_cond.unsqueeze(0).clone(), track_pos.unsqueeze(0), wanmove_embeds.get("strength", 1.0))[0]
             if transformer.in_dim == 16:
                 raise ValueError("T2V (text to video) model detected, encoded images only work with I2V (Image to video) models")
             elif transformer.in_dim not in [48, 32]: # fun 2.1 models don't use the mask
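For orientation: replace_feature from .WanMove.trajectory blends trajectory ("track") features into the image-conditioning latents at the tracked positions, scaled by a strength value. Below is a minimal sketch of that idea, not the repository's implementation -- the [C, T, H, W] / [N, T, (y, x)] shapes, the marker feature, and the linear blend are all illustrative assumptions:

import torch

def replace_feature_sketch(image_cond, track_pos, strength=1.0):
    # Illustrative stand-in for WanMove's replace_feature.
    # image_cond: [C, T, H, W] conditioning latents over T latent frames
    # track_pos:  [N, T, 2]   integer (y, x) per tracked point, per frame
    out = image_cond.clone()
    c, t, h, w = out.shape
    marker = torch.ones(c)  # stand-in for a per-track feature vector
    for n in range(track_pos.shape[0]):
        for f in range(min(t, track_pos.shape[1])):
            y, x = track_pos[n, f].tolist()
            if 0 <= y < h and 0 <= x < w:
                # Linearly blend the marker feature into the latent at
                # the tracked position, weighted by strength.
                out[:, f, y, x] = (1 - strength) * out[:, f, y, x] + strength * marker
    return out

Before this commit, the block removed above ran this replacement once over the full clip, unconditionally, which did not account for the sampler later splitting the clip into context windows.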
@@ -1227,6 +1219,16 @@ class WanVideoSampler:
                     latents_to_not_step = prev_latents.shape[1]
                     one_to_all_data["num_latent_frames_to_replace"] = latents_to_not_step
+        # WanMove
+        if image_cond is not None:
+            wanmove_embeds = image_embeds.get("wanmove_embeds", None)
+            if wanmove_embeds is not None:
+                track_pos = wanmove_embeds["track_pos"]
+                if any(not math.isclose(c, 1.0) for c in cfg):
+                    image_cond_neg = torch.cat([image_embeds["mask"], image_cond])
+                if context_options is None:
+                    image_cond = replace_feature(image_cond.unsqueeze(0).clone(), track_pos.unsqueeze(0), wanmove_embeds.get("strength", 1.0))[0]
         #region model pred
         def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
                              control_latents=None, vace_data=None, unianim_data=None, audio_proj=None, control_camera_latents=None,
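The behavioral core of the commit is in this relocated block: the full-clip replacement now runs only when context windows are disabled (context_options is None); with windows enabled, image_cond is left untouched here and each window handles its own slice during prediction. A condensed view of that control flow, reusing the hypothetical replace_feature_sketch above (batch dims and CFG handling omitted):

def apply_wanmove(image_cond, wanmove_embeds, context_options):
    # Condensed sketch of the pre-sampling WanMove branch.
    if image_cond is None or wanmove_embeds is None:
        return image_cond
    track_pos = wanmove_embeds["track_pos"]
    strength = wanmove_embeds.get("strength", 1.0)
    if context_options is None:
        # No context windows: inject trajectory features across the
        # whole clip once, before the sampling loop starts.
        return replace_feature_sketch(image_cond, track_pos, strength)
    # Context windows enabled: defer to the per-window path inside
    # predict_with_cfg (see the next hunk).
    return image_cond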
@@ -1465,6 +1467,9 @@ class WanVideoSampler:
             if background_latents is not None or foreground_latents is not None:
                 z = torch.cat([z, foreground_latents.to(z), background_latents.to(z)], dim=0)
+            if wanmove_embeds is not None and context_window is not None:
+                image_cond_input = replace_feature(image_cond_input.unsqueeze(0), track_pos[:, context_window].unsqueeze(0), wanmove_embeds.get("strength", 1.0))[0]
             base_params = {
                 'x': [z], # latent
                 'y': [image_cond_input] if image_cond_input is not None else None, # image cond
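Inside predict_with_cfg, track_pos is sliced to the current window's latent-frame indices (track_pos[:, context_window]) so the injected trajectory features stay aligned with the windowed image conditioning. A toy demonstration of that indexing, again reusing the hypothetical replace_feature_sketch and assuming context_window is a list of latent-frame indices:

import torch

image_cond_input = torch.zeros(4, 16, 8, 8)   # [C, T, H, W], T = 16 latent frames
track_pos = torch.randint(0, 8, (3, 16, 2))   # [N, T, (y, x)] tracked points
context_window = list(range(4, 12))           # this window covers frames 4..11

# Slice both the conditioning and the tracks down to the window before
# blending, mirroring track_pos[:, context_window] in the hunk above.
cond_window = image_cond_input[:, context_window]
tracks_window = track_pos[:, context_window]
cond_window = replace_feature_sketch(cond_window, tracks_window, strength=1.0)

Slicing the tracks per window keeps the frame axis of the trajectory data in lockstep with the frame axis of the windowed conditioning, which is what makes WanMove compatible with context windows.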