Mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git
Allow WanMove to work with context windows
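Summary of the change: WanMove injects trajectory features into the I2V image conditioning via replace_feature. Before this commit that happened once, over the whole clip, during input setup, which breaks when the sampler uses context windows (overlapping groups of latent frames denoised separately). The commit keeps the full-clip injection only when no context windows are configured, and otherwise slices the trajectory positions to the current window inside the prediction step. A hypothetical, self-contained sketch of the idea, where make_windows and every tensor shape are illustrative rather than the wrapper's real API:

    import torch

    def make_windows(num_frames, size, overlap):
        # naive overlapping window schedule over latent-frame indices (illustrative only)
        step = size - overlap
        starts = range(0, max(num_frames - overlap, 1), step)
        return [list(range(s, min(s + size, num_frames))) for s in starts]

    num_latent_frames = 21
    image_cond = torch.randn(20, num_latent_frames, 8, 8)  # assumed [C, T, H, W] conditioning
    track_pos = torch.randn(1, num_latent_frames, 4, 2)    # assumed [*, T, ...] trajectory positions

    for window in make_windows(num_latent_frames, size=9, overlap=3):
        cond_win = image_cond[:, window]   # conditioning restricted to this window's frames
        track_win = track_pos[:, window]   # matching trajectory slice, as in the diff below
        # replace_feature(cond_win.unsqueeze(0), track_win.unsqueeze(0), 1.0)[0] would run here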
@@ -16,6 +16,7 @@ from .utils import(log, print_memory, apply_lora, fourier_filter, optimized_scal
 from .cache_methods.cache_methods import cache_report
 from .nodes_model_loading import load_weights
 from .enhance_a_video.globals import set_enhance_weight, set_num_frames
+from .WanMove.trajectory import replace_feature
 from contextlib import nullcontext
 
 from comfy import model_management as mm
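The import added above moves replace_feature to module scope: the commit now references it both in the I2V setup path and inside the per-step prediction closure, so the function-local import removed in the next hunk becomes redundant.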
@@ -301,15 +302,6 @@ class WanVideoSampler:
         #I2V
         image_cond = image_embeds.get("image_embeds", None)
         if image_cond is not None:
-            # WanMove
-            wanmove_embeds = image_embeds.get("wanmove_embeds", None)
-            if wanmove_embeds is not None:
-                from .WanMove.trajectory import replace_feature
-                track_pos = wanmove_embeds["track_pos"]
-                if any(not math.isclose(c, 1.0) for c in cfg):
-                    image_cond_neg = torch.cat([image_embeds["mask"], image_cond])
-                image_cond = replace_feature(image_cond.unsqueeze(0).clone(), track_pos.unsqueeze(0), wanmove_embeds.get("strength", 1.0))[0]
-
             if transformer.in_dim == 16:
                 raise ValueError("T2V (text to video) model detected, encoded images only work with I2V (Image to video) models")
             elif transformer.in_dim not in [48, 32]: # fun 2.1 models don't use the mask
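The block removed here is not dropped; it is re-inserted further down (next hunk), after the context options are known, with a guard so the full-clip injection only runs when no context windows are in use. The cfg check it carries decides whether a negative (unconditional) image embedding is needed at all. A minimal, runnable illustration of that guard, with made-up values:

    import math

    cfg = [6.0, 6.0, 1.0]  # illustrative per-step guidance scales, not real values
    # True as soon as any step uses a scale other than 1.0, i.e. CFG is actually active
    needs_negative_cond = any(not math.isclose(c, 1.0) for c in cfg)
    print(needs_negative_cond)  # True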
@@ -1227,6 +1219,16 @@ class WanVideoSampler:
                     latents_to_not_step = prev_latents.shape[1]
                     one_to_all_data["num_latent_frames_to_replace"] = latents_to_not_step
 
+        # WanMove
+        if image_cond is not None:
+            wanmove_embeds = image_embeds.get("wanmove_embeds", None)
+            if wanmove_embeds is not None:
+                track_pos = wanmove_embeds["track_pos"]
+                if any(not math.isclose(c, 1.0) for c in cfg):
+                    image_cond_neg = torch.cat([image_embeds["mask"], image_cond])
+                if context_options is None:
+                    image_cond = replace_feature(image_cond.unsqueeze(0).clone(), track_pos.unsqueeze(0), wanmove_embeds.get("strength", 1.0))[0]
+
         #region model pred
         def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, idx, image_cond=None, clip_fea=None,
             control_latents=None, vace_data=None, unianim_data=None, audio_proj=None, control_camera_latents=None,
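Two details in the relocated block: the full-clip replace_feature call is now gated on context_options is None, and it still operates on image_cond.unsqueeze(0).clone(). The .clone() is presumably a defensive copy, so that any in-place work inside replace_feature cannot write through the unsqueezed view back into the original conditioning tensor. A small PyTorch demonstration of the aliasing it avoids:

    import torch

    a = torch.zeros(3)
    view = a.unsqueeze(0)          # unsqueeze returns a view sharing a's storage
    view[0, 0] = 1.0
    print(a[0].item())             # 1.0 -- writing through the view changed the original
    safe = a.unsqueeze(0).clone()  # .clone() detaches the storage, as the diff does
    safe[0, 1] = 2.0
    print(a[1].item())             # 0.0 -- the original is untouched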
@@ -1465,6 +1467,9 @@ class WanVideoSampler:
             if background_latents is not None or foreground_latents is not None:
                 z = torch.cat([z, foreground_latents.to(z), background_latents.to(z)], dim=0)
 
+            if wanmove_embeds is not None and context_window is not None:
+                image_cond_input = replace_feature(image_cond_input.unsqueeze(0), track_pos[:, context_window].unsqueeze(0), wanmove_embeds.get("strength", 1.0))[0]
+
             base_params = {
                 'x': [z], # latent
                 'y': [image_cond_input] if image_cond_input is not None else None, # image cond
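This is the core of the fix: inside predict_with_cfg, image_cond_input (the conditioning the closure receives for the current window) gets the trajectory features for just that window, so track_pos is sliced to the same latent-frame indices with track_pos[:, context_window] before replace_feature is applied. A runnable sketch of that indexing, with an assumed, illustrative tensor layout:

    import torch

    track_pos = torch.randn(1, 21, 4, 2)          # assumed [*, T, ...] trajectory layout
    context_window = list(range(6, 15))           # latent-frame indices of one window
    window_tracks = track_pos[:, context_window]  # same fancy indexing as the diff
    print(window_tracks.shape)                    # torch.Size([1, 9, 4, 2])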