mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-28 12:20:55 +03:00

cleanup and use fp32 text/time embed if available

kijai
2025-09-30 18:10:18 +03:00
parent 4f42dbfacf
commit 1ba1a1662b
4 changed files with 47 additions and 52 deletions

View File

@@ -1011,7 +1011,7 @@ class WanVideoSampler:
lynx_ref_latent_uncond = lynx_ref_latent_uncond[0]
lynx_embeds["ref_feature_extractor"] = True
log.info(f"Lynx ref latent shape: {lynx_ref_latent.shape}")
log.info("Extracting Lynx ref buffer...")
log.info("Extracting Lynx ref cond buffer...")
lynx_ref_buffer = transformer(
[lynx_ref_latent.to(device, dtype)],
torch.tensor([0], device=device),
@@ -1019,7 +1019,9 @@ class WanVideoSampler:
seq_len=math.ceil((lynx_ref_latent.shape[2] * lynx_ref_latent.shape[3]) / 4 * lynx_ref_latent.shape[1]),
lynx_embeds=lynx_embeds
)
log.info(f"Extracted {len(lynx_ref_buffer)} cond ref buffers")
if not math.isclose(cfg[0], 1.0):
log.info("Extracting Lynx ref uncond buffer...")
lynx_ref_buffer_uncond = transformer(
[lynx_ref_latent_uncond.to(device, dtype)],
torch.tensor([0], device=device),
@@ -1028,8 +1030,8 @@ class WanVideoSampler:
lynx_embeds=lynx_embeds,
is_uncond=True
)
log.info(f"Extracted {len(lynx_ref_buffer_uncond)} uncond ref buffers")
log.info(f"Extracted {len(lynx_ref_buffer)} ref buffers")
lynx_embeds["ref_feature_extractor"] = False
lynx_embeds["ref_latent"] = lynx_embeds["ref_text_embed"] = None
lynx_embeds["ref_buffer"] = lynx_ref_buffer
@@ -1044,7 +1046,7 @@ class WanVideoSampler:
humo_image_cond=None, humo_image_cond_neg=None, humo_audio=None, humo_audio_neg=None, wananim_pose_latents=None,
wananim_face_pixels=None, uni3c_data=None,):
nonlocal transformer
z = z.to(dtype)
#z = z.to(dtype)
autocast_enabled = ("fp8" in model["quantization"] and not transformer.patched_linear)
with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype) if autocast_enabled else nullcontext():
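A hedged illustration of the context handling shown above: autocast is only entered when the model is fp8-quantized without patched linear layers, and a no-op context is used otherwise, which is why the explicit `z = z.to(dtype)` cast can be commented out. Function and argument names here are placeholders.

```python
from contextlib import nullcontext
import torch

def run_transformer(model_fn, z, quantization, patched_linear, device_type="cuda", dtype=torch.bfloat16):
    # Autocast only for fp8 weights that still go through unpatched linear layers.
    autocast_enabled = "fp8" in quantization and not patched_linear
    ctx = torch.autocast(device_type=device_type, dtype=dtype) if autocast_enabled else nullcontext()
    with ctx:
        return model_fn(z)
```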
@@ -2764,8 +2766,7 @@ class WanVideoSampler:
else:
noise_pred, self.cache_state = predict_with_cfg(
latent_model_input,
cfg[idx], text_embeds["prompt_embeds"],
text_embeds["negative_prompt_embeds"],
cfg[idx], text_embeds["prompt_embeds"], text_embeds["negative_prompt_embeds"],
timestep, idx, image_cond, clip_fea, control_latents, vace_data, unianim_data, audio_proj, control_camera_latents, add_cond,
cache_state=self.cache_state, fantasy_portrait_input=fantasy_portrait_input, multitalk_audio_embeds=multitalk_audio_embeds, mtv_motion_tokens=mtv_motion_tokens, s2v_audio_input=s2v_audio_input,
humo_image_cond=humo_image_cond, humo_image_cond_neg=humo_image_cond_neg, humo_audio=humo_audio, humo_audio_neg=humo_audio_neg,
@@ -2774,8 +2775,7 @@ class WanVideoSampler:
if bidirectional_sampling:
noise_pred_flipped, self.cache_state = predict_with_cfg(
latent_model_input_flipped,
cfg[idx], text_embeds["prompt_embeds"],
text_embeds["negative_prompt_embeds"],
cfg[idx], text_embeds["prompt_embeds"], text_embeds["negative_prompt_embeds"],
timestep, idx, image_cond, clip_fea, control_latents, vace_data, unianim_data, audio_proj, control_camera_latents, add_cond,
cache_state=self.cache_state, fantasy_portrait_input=fantasy_portrait_input, mtv_motion_tokens=mtv_motion_tokens,reverse_time=True)

View File

@@ -35,9 +35,12 @@ except Exception as e:
sageattn_func = None
try:
from sageattn import sageattn_blackwell
from sageattn3 import sageattn3_blackwell as sageattn_blackwell
except:
SAGE3_AVAILABLE = False
try:
from sageattn import sageattn_blackwell
except:
SAGE3_AVAILABLE = False
__all__ = [

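An illustrative sketch of the import fallback this hunk introduces: prefer the `sageattn3` Blackwell kernel, fall back to the older `sageattn` export, and mark SageAttention 3 unavailable when neither import succeeds. This is one plausible reading of the (indentation-stripped) hunk; the `None` default is an assumption added for self-containment.

```python
SAGE3_AVAILABLE = True
try:
    # Preferred: the dedicated sageattn3 Blackwell kernel.
    from sageattn3 import sageattn3_blackwell as sageattn_blackwell
except Exception:
    try:
        # Fallback: the older export from the sageattn package.
        from sageattn import sageattn_blackwell
    except Exception:
        sageattn_blackwell = None
        SAGE3_AVAILABLE = False
```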
View File

@@ -961,39 +961,25 @@ class WanAttentionBlock(nn.Module):
#region attention forward
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
current_step,
self, x, e, seq_lens, grid_sizes, freqs, context, current_step,
last_step=False,
video_attention_split_steps=[],
clip_embed=None,
camera_embed=None,
audio_proj=None,
audio_scale=1.0,
camera_embed=None, #ReCamMaster
audio_proj=None, audio_scale=1.0, #fantasytalking
num_latent_frames=21,
original_seq_len=None,
enhance_enabled=False,
nag_params={},
nag_context=None,
enhance_enabled=False, #feta
nag_params={}, nag_context=None, #normalized attention guidance
is_uncond=False,
multitalk_audio_embedding=None,
ref_target_masks=None,
human_num=0,
inner_t=None, inner_c=None,
cross_freqs=None,
x_ip=None, e_ip=None,
freqs_ip=None,
adapter_proj=None,
ip_scale=1.0,
multitalk_audio_embedding=None, ref_target_masks=None, human_num=0, #multitalk
inner_t=None, inner_c=None, cross_freqs=None, #echoshot
x_ip=None, e_ip=None, freqs_ip=None, ip_scale=1.0, #stand-in
adapter_proj=None, #fantasyportrait
reverse_time=False,
mtv_motion_tokens=None, mtv_motion_rotary_emb=None, mtv_strength=1.0, mtv_freqs=None,
humo_audio_input=None, humo_audio_scale=1.0,
lynx_x_ip=None, lynx_ref_feature=None, lynx_ip_scale=1.0, lynx_ref_scale=1.0,
mtv_motion_tokens=None, mtv_motion_rotary_emb=None, mtv_strength=1.0, mtv_freqs=None, #mtv crafter
humo_audio_input=None, humo_audio_scale=1.0, #humo audio
lynx_x_ip=None, lynx_ref_feature=None, lynx_ip_scale=1.0, lynx_ref_scale=1.0, #lynx
):
r"""
Args:
@@ -1652,6 +1638,8 @@ class WanModel(torch.nn.Module):
self.motion_encoder_dim = motion_encoder_dim
self.base_dtype = dtype
# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
@@ -2330,7 +2318,7 @@ class WanModel(torch.nn.Module):
if self.zero_timestep:
t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)])
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(x.dtype)) # b, dim
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.time_embedding[0].weight.dtype)) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim
#S2V zero timestep
@@ -2346,7 +2334,7 @@ class WanModel(torch.nn.Module):
if x_ip is not None:
timestep_ip = torch.zeros_like(t) # [B] with 0s
t_ip = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep_ip.flatten()).to(x.dtype)) # b, dim )
t_ip = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep_ip.flatten()).to(self.time_embedding[0].weight.dtype)) # b, dim )
e0_ip = self.time_projection(t_ip).unflatten(1, (6, self.dim))
if fps_embeds is not None:
@@ -2387,7 +2375,7 @@ class WanModel(torch.nn.Module):
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]).to(x.dtype))
]).to(self.text_embedding[0].weight.dtype))
# NAG
if nag_context is not None:
@@ -2396,7 +2384,7 @@ class WanModel(torch.nn.Module):
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in nag_context
]).to(x.dtype))
]).to(self.text_embedding[0].weight.dtype))
if self.offload_txt_emb:
self.text_embedding.to(self.offload_device, non_blocking=self.use_non_blocking)
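A minimal sketch of the dtype change applied in the time- and text-embedding hunks above, assuming PyTorch: the embedding inputs are cast to the embedding module's own weight dtype rather than to the activation dtype `x.dtype`, so fp32 embedding weights are actually evaluated in fp32. Layer sizes below are placeholders.

```python
import torch
import torch.nn as nn

# Stand-in for the model's time_embedding MLP, kept in fp32.
time_embedding = nn.Sequential(nn.Linear(256, 1024), nn.SiLU(), nn.Linear(1024, 1024)).float()

# Stand-in for sinusoidal_embedding_1d(freq_dim, t), arriving in the activation dtype.
sin_emb = torch.randn(2, 256, dtype=torch.float16)

# Cast to the first layer's weight dtype instead of x.dtype.
e = time_embedding(sin_emb.to(time_embedding[0].weight.dtype))
print(e.dtype)  # torch.float32
```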
@@ -2567,6 +2555,8 @@ class WanModel(torch.nn.Module):
else:
should_calc = True
x = x.to(self.base_dtype)
if self.enable_easycache:
original_x = x.clone().to(self.cache_device)
if should_calc:
@@ -2661,12 +2651,11 @@ class WanModel(torch.nn.Module):
# lynx ref
if lynx_ref_buffer is None and lynx_ref_feature_extractor:
lynx_ref_buffer = {}
print("Lynx reference feature extractor enabled.")
for b, block in enumerate(self.blocks):
block_idx = f"{b:02d}"
if lynx_ref_buffer is not None and not lynx_ref_feature_extractor:
print("reading from lynx ref buffer for block", block_idx)
#print("reading from lynx ref buffer for block", block_idx)
lynx_ref_feature = lynx_ref_buffer.get(block_idx, None)
else:
lynx_ref_feature = None

View File

@@ -998,7 +998,7 @@ class VideoVAE_(nn.Module):
return mu
def encode(self, x, pbar=True):
def encode(self, x, pbar=True, sample=False):
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
if pbar:
@@ -1018,11 +1018,15 @@ class VideoVAE_(nn.Module):
if pbar:
pbar.update(1)
self.clear_cache()
mu = self.conv1(out).chunk(2, dim=1)[0]
mu, log_var = self.conv1(out).chunk(2, dim=1)
mu = (mu - self.mean.to(mu)) * self.inv_std.to(mu)
if pbar:
pbar.update_absolute(0)
if sample:
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
eps = torch.randn_like(std)
return mu + std * eps
return mu
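For reference, a hedged sketch of the optional sampling path added to `encode()`: when `sample=True`, a latent is drawn from the posterior with the reparameterization trick instead of returning only the mean. The latent normalization with `mean`/`inv_std` from the hunk is omitted here for brevity.

```python
import torch

def sample_or_mean(conv_out: torch.Tensor, sample: bool = False) -> torch.Tensor:
    # The conv1 output holds mean and log-variance stacked along the channel dim.
    mu, log_var = conv_out.chunk(2, dim=1)
    if not sample:
        return mu
    # Reparameterization trick with a clamped log-variance for numerical safety.
    std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
    eps = torch.randn_like(std)
    return mu + std * eps
```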
@@ -1271,21 +1275,20 @@ class WanVideoVAE(nn.Module):
return values
def single_encode(self, video, device, pbar=True):
def single_encode(self, video, device, pbar=True, sample=False):
video = video.to(device)
x = self.model.encode(video, pbar=pbar)
x = self.model.encode(video, pbar=pbar, sample=sample)
return x.float()
def single_decode(self, hidden_state, device, pbar=True):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, pbar=pbar)
return video
def double_encode(self, video, device):
def double_encode(self, video, device, pbar=True, sample=False):
print('double_encode')
video = video.to(device)
x = self.model.encode_2(video)
x = self.model.encode_2(video, pbar=pbar, sample=sample)
return x.float()
def double_decode(self, hidden_state, device):
@@ -1294,7 +1297,7 @@ class WanVideoVAE(nn.Module):
video = self.model.decode_2(hidden_state)
return video
def encode(self, videos, device, tiled=False,end_=False, tile_size=None, tile_stride=None, pbar=True):
def encode(self, videos, device, tiled=False,end_=False, tile_size=None, tile_stride=None, pbar=True, sample=False):
self.model.clear_cache()
videos = [video.to("cpu") for video in videos]
hidden_states = []
@@ -1306,7 +1309,7 @@ class WanVideoVAE(nn.Module):
if end_:
hidden_state = self.double_encode(video, device)
else:
hidden_state = self.single_encode(video, device, pbar=pbar)
hidden_state = self.single_encode(video, device, pbar=pbar, sample=sample)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
hidden_states = torch.stack(hidden_states)
@@ -1366,7 +1369,7 @@ class VideoVAE38_(VideoVAE_):
attn_scales, self.temperal_upsample, dropout)
def encode(self, x, pbar=True):
def encode(self, x, pbar=True, sample=False):
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]