From 3aae54f220dc172a7efcb108c293dd68ef960d16 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 15 Dec 2025 01:22:21 +0200
Subject: [PATCH] Allow WanMove to work with context windows

---
 nodes_sampler.py | 23 ++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/nodes_sampler.py b/nodes_sampler.py
index 0329db3..a0cd623 100644
--- a/nodes_sampler.py
+++ b/nodes_sampler.py
@@ -16,6 +16,7 @@ from .utils import(log, print_memory, apply_lora, fourier_filter, optimized_scal
 from .cache_methods.cache_methods import cache_report
 from .nodes_model_loading import load_weights
 from .enhance_a_video.globals import set_enhance_weight, set_num_frames
+from .WanMove.trajectory import replace_feature
 from contextlib import nullcontext
 
 from comfy import model_management as mm
@@ -301,15 +302,6 @@ class WanVideoSampler:
             #I2V
             image_cond = image_embeds.get("image_embeds", None)
             if image_cond is not None:
-                # WanMove
-                wanmove_embeds = image_embeds.get("wanmove_embeds", None)
-                if wanmove_embeds is not None:
-                    from .WanMove.trajectory import replace_feature
-                    track_pos = wanmove_embeds["track_pos"]
-                    if any(not math.isclose(c, 1.0) for c in cfg):
-                        image_cond_neg = torch.cat([image_embeds["mask"], image_cond])
-                    image_cond = replace_feature(image_cond.unsqueeze(0).clone(), track_pos.unsqueeze(0), wanmove_embeds.get("strength", 1.0))[0]
-
                 if transformer.in_dim == 16:
                     raise ValueError("T2V (text to video) model detected, encoded images only work with I2V (Image to video) models")
                 elif transformer.in_dim not in [48, 32]: # fun 2.1 models don't use the mask
@@ -1227,6 +1219,16 @@ class WanVideoSampler:
                 latents_to_not_step = prev_latents.shape[1]
                 one_to_all_data["num_latent_frames_to_replace"] = latents_to_not_step
 
+        # WanMove
+        if image_cond is not None:
+            wanmove_embeds = image_embeds.get("wanmove_embeds", None)
+            if wanmove_embeds is not None:
+                track_pos = wanmove_embeds["track_pos"]
+                if any(not math.isclose(c, 1.0) for c in cfg):
+                    image_cond_neg = torch.cat([image_embeds["mask"], image_cond])
+                if context_options is None:
+                    image_cond = replace_feature(image_cond.unsqueeze(0).clone(), track_pos.unsqueeze(0), wanmove_embeds.get("strength", 1.0))[0]
+
         #region model pred
         def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
                              control_latents=None, vace_data=None, unianim_data=None, audio_proj=None, control_camera_latents=None,
@@ -1465,6 +1467,9 @@ class WanVideoSampler:
                 if background_latents is not None or foreground_latents is not None:
                     z = torch.cat([z, foreground_latents.to(z), background_latents.to(z)], dim=0)
 
+                if wanmove_embeds is not None and context_window is not None:
+                    image_cond_input = replace_feature(image_cond_input.unsqueeze(0), track_pos[:, context_window].unsqueeze(0), wanmove_embeds.get("strength", 1.0))[0]
+
                 base_params = {
                     'x': [z], # latent
                     'y': [image_cond_input] if image_cond_input is not None else None, # image cond
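
Note (illustrative, not part of the patch): the sketch below shows the per-window behavior this change introduces. Instead of blending the WanMove trajectory features into the image conditioning once over the full sequence, the trajectory is sliced to each context window (track_pos[:, context_window]) and blended into that window's conditioning inside predict_with_cfg. The replace_feature stand-in here is a hypothetical linear blend; the real implementation lives in .WanMove.trajectory and may differ. The tensor shapes and window list are also assumptions.

# Illustrative sketch only. replace_feature below is a HYPOTHETICAL stand-in
# for .WanMove.trajectory.replace_feature: a linear blend of trajectory
# features into the conditioning at the given strength.
import torch

def replace_feature(cond, track, strength=1.0):
    # Assumed behavior: blend track features into cond; shapes must match.
    return cond * (1.0 - strength) + track * strength

C, T, H, W = 20, 21, 60, 104                 # assumed latent-space shape
image_cond = torch.randn(C, T, H, W)         # full-length image conditioning
track_pos = torch.randn(C, T, H, W)          # full-length trajectory features

# Non-windowed path (context_options is None): apply once, full sequence.
image_cond_full = replace_feature(image_cond.unsqueeze(0).clone(),
                                  track_pos.unsqueeze(0))[0]

# Windowed path: slice the trajectory to each window's frames and blend it
# into that window's conditioning slice only.
windows = [list(range(0, 12)), list(range(9, 21))]   # assumed overlapping windows
for context_window in windows:
    image_cond_input = image_cond[:, context_window]
    image_cond_input = replace_feature(image_cond_input.unsqueeze(0),
                                       track_pos[:, context_window].unsqueeze(0))[0]
    assert image_cond_input.shape == (C, len(context_window), H, W)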