Mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git (synced 2026-01-28 12:20:55 +03:00)
cleanup and use fp32 text/time embed if available
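Below is a minimal sketch (not the wrapper's actual code) of the dtype-following pattern this commit applies: instead of casting the sinusoidal timestep/text embedding input to the model's compute dtype (x.dtype), it is cast to the embedding module's own weight dtype, so an fp32 time/text embedding keeps running in fp32 even when the rest of the model is fp16/bf16. The helper below is a simplified stand-in for the repo's sinusoidal_embedding_1d.

import torch
import torch.nn as nn

dim, freq_dim = 2048, 256

# time embedding kept in fp32 while the rest of the model could run in bf16
time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)).float()

def sinusoidal_embedding_1d(freq_dim: int, t: torch.Tensor) -> torch.Tensor:
    # simplified stand-in for the repo's helper of the same name
    half = freq_dim // 2
    freqs = torch.exp(-torch.arange(half, dtype=torch.float32) * (torch.log(torch.tensor(10000.0)) / half))
    args = t.float()[:, None] * freqs[None]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

t = torch.tensor([999, 500], dtype=torch.long)
# old: .to(x.dtype)  -> would downcast the embedding input to the model's compute dtype
# new: .to(time_embedding[0].weight.dtype) -> follows the layer's own dtype, here fp32
e = time_embedding(sinusoidal_embedding_1d(freq_dim, t.flatten()).to(time_embedding[0].weight.dtype))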
This commit is contained in:
@@ -1011,7 +1011,7 @@ class WanVideoSampler:
lynx_ref_latent_uncond = lynx_ref_latent_uncond[0]
lynx_embeds["ref_feature_extractor"] = True
log.info(f"Lynx ref latent shape: {lynx_ref_latent.shape}")
log.info("Extracting Lynx ref buffer...")
log.info("Extracting Lynx ref cond buffer...")
lynx_ref_buffer = transformer(
[lynx_ref_latent.to(device, dtype)],
torch.tensor([0], device=device),
@@ -1019,7 +1019,9 @@ class WanVideoSampler:
seq_len=math.ceil((lynx_ref_latent.shape[2] * lynx_ref_latent.shape[3]) / 4 * lynx_ref_latent.shape[1]),
lynx_embeds=lynx_embeds
)
log.info(f"Extracted {len(lynx_ref_buffer)} cond ref buffers")
if not math.isclose(cfg[0], 1.0):
log.info("Extracting Lynx ref uncond buffer...")
lynx_ref_buffer_uncond = transformer(
[lynx_ref_latent_uncond.to(device, dtype)],
torch.tensor([0], device=device),
@@ -1028,8 +1030,8 @@ class WanVideoSampler:
lynx_embeds=lynx_embeds,
is_uncond=True
)
log.info(f"Extracted {len(lynx_ref_buffer_uncond)} uncond ref buffers")

log.info(f"Extracted {len(lynx_ref_buffer)} ref buffers")
lynx_embeds["ref_feature_extractor"] = False
lynx_embeds["ref_latent"] = lynx_embeds["ref_text_embed"] = None
lynx_embeds["ref_buffer"] = lynx_ref_buffer
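A rough sketch of the Lynx reference-buffer flow shown in these hunks, under assumptions (this is not the wrapper's exact API): the transformer is run once at timestep 0 over the reference latent with ref_feature_extractor enabled to collect per-block features, a second pass collects unconditional features only when CFG is actually active, and the buffers are then handed back through lynx_embeds for the real sampling steps.

import math
import torch

def extract_lynx_ref_buffers(transformer, ref_latent, ref_latent_uncond, lynx_embeds, cfg_scale, device, dtype):
    lynx_embeds["ref_feature_extractor"] = True
    seq_len = math.ceil((ref_latent.shape[2] * ref_latent.shape[3]) / 4 * ref_latent.shape[1])
    zero_t = torch.tensor([0], device=device)

    cond_buffer = transformer([ref_latent.to(device, dtype)], zero_t,
                              seq_len=seq_len, lynx_embeds=lynx_embeds)

    uncond_buffer = None
    if not math.isclose(cfg_scale, 1.0):  # skip the extra pass when CFG is effectively off
        uncond_buffer = transformer([ref_latent_uncond.to(device, dtype)], zero_t,
                                    seq_len=seq_len, lynx_embeds=lynx_embeds, is_uncond=True)

    # switch back to normal sampling mode and pass the captured features along;
    # where the uncond buffer is stored is not shown in this hunk, so the key is hypothetical
    lynx_embeds["ref_feature_extractor"] = False
    lynx_embeds["ref_latent"] = lynx_embeds["ref_text_embed"] = None
    lynx_embeds["ref_buffer"] = cond_buffer
    lynx_embeds["ref_buffer_uncond"] = uncond_buffer
    return lynx_embeds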
@@ -1044,7 +1046,7 @@ class WanVideoSampler:
humo_image_cond=None, humo_image_cond_neg=None, humo_audio=None, humo_audio_neg=None, wananim_pose_latents=None,
wananim_face_pixels=None, uni3c_data=None,):
nonlocal transformer
z = z.to(dtype)
#z = z.to(dtype)
autocast_enabled = ("fp8" in model["quantization"] and not transformer.patched_linear)
with torch.autocast(device_type=mm.get_autocast_device(device), dtype=dtype) if autocast_enabled else nullcontext():
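A minimal sketch of the conditional-autocast idiom used above: wrap the forward pass in torch.autocast only when it is actually needed (in the diff, when fp8 quantization is used without patched linear layers), and otherwise fall through a no-op context manager. The function name, the use of device.type instead of the wrapper's mm.get_autocast_device helper, and model_fn are illustrative assumptions.

from contextlib import nullcontext
import torch

def run_forward(model_fn, z, device: torch.device, dtype, autocast_enabled: bool):
    # no-op context when autocast is not wanted, so the call site stays a single `with`
    ctx = torch.autocast(device_type=device.type, dtype=dtype) if autocast_enabled else nullcontext()
    with ctx:
        return model_fn(z.to(dtype))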
@@ -2764,8 +2766,7 @@ class WanVideoSampler:
else:
noise_pred, self.cache_state = predict_with_cfg(
latent_model_input,
cfg[idx], text_embeds["prompt_embeds"],
text_embeds["negative_prompt_embeds"],
cfg[idx], text_embeds["prompt_embeds"], text_embeds["negative_prompt_embeds"],
timestep, idx, image_cond, clip_fea, control_latents, vace_data, unianim_data, audio_proj, control_camera_latents, add_cond,
cache_state=self.cache_state, fantasy_portrait_input=fantasy_portrait_input, multitalk_audio_embeds=multitalk_audio_embeds, mtv_motion_tokens=mtv_motion_tokens, s2v_audio_input=s2v_audio_input,
humo_image_cond=humo_image_cond, humo_image_cond_neg=humo_image_cond_neg, humo_audio=humo_audio, humo_audio_neg=humo_audio_neg,
@@ -2774,8 +2775,7 @@ class WanVideoSampler:
if bidirectional_sampling:
noise_pred_flipped, self.cache_state = predict_with_cfg(
latent_model_input_flipped,
cfg[idx], text_embeds["prompt_embeds"],
text_embeds["negative_prompt_embeds"],
cfg[idx], text_embeds["prompt_embeds"], text_embeds["negative_prompt_embeds"],
timestep, idx, image_cond, clip_fea, control_latents, vace_data, unianim_data, audio_proj, control_camera_latents, add_cond,
cache_state=self.cache_state, fantasy_portrait_input=fantasy_portrait_input, mtv_motion_tokens=mtv_motion_tokens,reverse_time=True)


@@ -35,9 +35,12 @@ except Exception as e:
sageattn_func = None

try:
from sageattn import sageattn_blackwell
from sageattn3 import sageattn3_blackwell as sageattn_blackwell
except:
SAGE3_AVAILABLE = False
try:
from sageattn import sageattn_blackwell
except:
SAGE3_AVAILABLE = False


__all__ = [
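One way to express the layered import fallback shown in the hunk above: prefer the sageattn3 Blackwell kernel, fall back to the older sageattn export, and record availability in a flag rather than letting the import error propagate. The exact nesting and where the flag is initialized are not visible in the extracted diff, so this sketch is an assumption about intent; the module and function names follow the diff.

SAGE3_AVAILABLE = True
try:
    from sageattn3 import sageattn3_blackwell as sageattn_blackwell  # newer package name
except ImportError:
    try:
        from sageattn import sageattn_blackwell  # older export of the same kernel
    except ImportError:
        sageattn_blackwell = None
        SAGE3_AVAILABLE = False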
@@ -961,39 +961,25 @@ class WanAttentionBlock(nn.Module):

#region attention forward
def forward(
self,
x,
e,
seq_lens,
grid_sizes,
freqs,
context,
current_step,
self, x, e, seq_lens, grid_sizes, freqs, context, current_step,
last_step=False,
video_attention_split_steps=[],
clip_embed=None,
camera_embed=None,
audio_proj=None,
audio_scale=1.0,
camera_embed=None, #ReCamMaster
audio_proj=None, audio_scale=1.0, #fantasytalking
num_latent_frames=21,
original_seq_len=None,
enhance_enabled=False,
nag_params={},
nag_context=None,
enhance_enabled=False, #feta
nag_params={}, nag_context=None, #normalized attention guidance
is_uncond=False,
multitalk_audio_embedding=None,
ref_target_masks=None,
human_num=0,
inner_t=None, inner_c=None,
cross_freqs=None,
x_ip=None, e_ip=None,
freqs_ip=None,
adapter_proj=None,
ip_scale=1.0,
multitalk_audio_embedding=None, ref_target_masks=None, human_num=0, #multitalk
inner_t=None, inner_c=None, cross_freqs=None, #echoshot
x_ip=None, e_ip=None, freqs_ip=None, ip_scale=1.0, #stand-in
adapter_proj=None, #fantasyportrait
reverse_time=False,
mtv_motion_tokens=None, mtv_motion_rotary_emb=None, mtv_strength=1.0, mtv_freqs=None,
humo_audio_input=None, humo_audio_scale=1.0,
lynx_x_ip=None, lynx_ref_feature=None, lynx_ip_scale=1.0, lynx_ref_scale=1.0,
mtv_motion_tokens=None, mtv_motion_rotary_emb=None, mtv_strength=1.0, mtv_freqs=None, #mtv crafter
humo_audio_input=None, humo_audio_scale=1.0, #humo audio
lynx_x_ip=None, lynx_ref_feature=None, lynx_ip_scale=1.0, lynx_ref_scale=1.0, #lynx
):
r"""
Args:
@@ -1652,6 +1638,8 @@ class WanModel(torch.nn.Module):

self.motion_encoder_dim = motion_encoder_dim

self.base_dtype = dtype

# embeddings
self.patch_embedding = nn.Conv3d(
in_dim, dim, kernel_size=patch_size, stride=patch_size)
@@ -2330,7 +2318,7 @@ class WanModel(torch.nn.Module):
if self.zero_timestep:
t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)])

e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(x.dtype)) # b, dim
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(self.time_embedding[0].weight.dtype)) # b, dim
e0 = self.time_projection(e).unflatten(1, (6, self.dim)) # b, 6, dim

#S2V zero timestep
@@ -2346,7 +2334,7 @@ class WanModel(torch.nn.Module):

if x_ip is not None:
timestep_ip = torch.zeros_like(t) # [B] with 0s
t_ip = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep_ip.flatten()).to(x.dtype)) # b, dim )
t_ip = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep_ip.flatten()).to(self.time_embedding[0].weight.dtype)) # b, dim )
e0_ip = self.time_projection(t_ip).unflatten(1, (6, self.dim))

if fps_embeds is not None:
@@ -2387,7 +2375,7 @@ class WanModel(torch.nn.Module):
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]).to(x.dtype))
]).to(self.text_embedding[0].weight.dtype))

# NAG
if nag_context is not None:
@@ -2396,7 +2384,7 @@ class WanModel(torch.nn.Module):
torch.cat(
[u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in nag_context
]).to(x.dtype))
]).to(self.text_embedding[0].weight.dtype))

if self.offload_txt_emb:
self.text_embedding.to(self.offload_device, non_blocking=self.use_non_blocking)
@@ -2567,6 +2555,8 @@ class WanModel(torch.nn.Module):
else:
should_calc = True

x = x.to(self.base_dtype)

if self.enable_easycache:
original_x = x.clone().to(self.cache_device)
if should_calc:
@@ -2661,12 +2651,11 @@ class WanModel(torch.nn.Module):
# lynx ref
if lynx_ref_buffer is None and lynx_ref_feature_extractor:
lynx_ref_buffer = {}
print("Lynx reference feature extractor enabled.")

for b, block in enumerate(self.blocks):
block_idx = f"{b:02d}"
if lynx_ref_buffer is not None and not lynx_ref_feature_extractor:
print("reading from lynx ref buffer for block", block_idx)
#print("reading from lynx ref buffer for block", block_idx)
lynx_ref_feature = lynx_ref_buffer.get(block_idx, None)
else:
lynx_ref_feature = None
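A tiny sketch of the per-block buffer lookup in the hunk above: during extraction each block's reference feature is stored under a zero-padded index key, and during sampling the same key reads it back, falling back to None for blocks with no entry. The dictionary contents here are placeholders, not real features.

lynx_ref_buffer = {"00": "feat0", "01": "feat1"}   # normally tensors captured per transformer block
for b in range(3):
    block_idx = f"{b:02d}"
    lynx_ref_feature = lynx_ref_buffer.get(block_idx, None)  # None for missing blocks, e.g. "02"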
@@ -998,7 +998,7 @@ class VideoVAE_(nn.Module):
return mu


def encode(self, x, pbar=True):
def encode(self, x, pbar=True, sample=False):
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
if pbar:
@@ -1018,11 +1018,15 @@ class VideoVAE_(nn.Module):
if pbar:
pbar.update(1)
self.clear_cache()
mu = self.conv1(out).chunk(2, dim=1)[0]
mu, log_var = self.conv1(out).chunk(2, dim=1)
mu = (mu - self.mean.to(mu)) * self.inv_std.to(mu)
if pbar:
pbar.update_absolute(0)


if sample:
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
eps = torch.randn_like(std)
return mu + std * eps
return mu
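A self-contained sketch of the reparameterization added above: the encoder's conv head predicts both the mean and the log-variance; with sample=True the latent is drawn as mu + sigma * eps instead of just returning the mean, and clamping log_var keeps exp() numerically safe. The conv head, channel counts, and shapes below are illustrative, not the wrapper's, and the normalization by mean/inv_std from the diff is omitted.

import torch
import torch.nn as nn

conv1 = nn.Conv3d(16, 32, kernel_size=1)   # predicts 2 * latent_channels

def encode_head(out: torch.Tensor, sample: bool = False) -> torch.Tensor:
    mu, log_var = conv1(out).chunk(2, dim=1)
    if sample:
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        eps = torch.randn_like(std)
        return mu + std * eps      # stochastic latent (reparameterization trick)
    return mu                      # deterministic mean, the previous behaviour

latent = encode_head(torch.randn(1, 16, 4, 8, 8), sample=True)  # -> [1, 16, 4, 8, 8]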
@@ -1271,21 +1275,20 @@ class WanVideoVAE(nn.Module):
return values


def single_encode(self, video, device, pbar=True):
def single_encode(self, video, device, pbar=True, sample=False):
video = video.to(device)
x = self.model.encode(video, pbar=pbar)
x = self.model.encode(video, pbar=pbar, sample=sample)
return x.float()


def single_decode(self, hidden_state, device, pbar=True):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, pbar=pbar)
return video

def double_encode(self, video, device):
def double_encode(self, video, device, pbar=True, sample=False):
print('double_encode')
video = video.to(device)
x = self.model.encode_2(video)
x = self.model.encode_2(video, pbar=pbar, sample=sample)
return x.float()

def double_decode(self, hidden_state, device):
@@ -1294,7 +1297,7 @@ class WanVideoVAE(nn.Module):
video = self.model.decode_2(hidden_state)
return video

def encode(self, videos, device, tiled=False,end_=False, tile_size=None, tile_stride=None, pbar=True):
def encode(self, videos, device, tiled=False,end_=False, tile_size=None, tile_stride=None, pbar=True, sample=False):
self.model.clear_cache()
videos = [video.to("cpu") for video in videos]
hidden_states = []
@@ -1306,7 +1309,7 @@ class WanVideoVAE(nn.Module):
if end_:
hidden_state = self.double_encode(video, device)
else:
hidden_state = self.single_encode(video, device, pbar=pbar)
hidden_state = self.single_encode(video, device, pbar=pbar, sample=sample)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
hidden_states = torch.stack(hidden_states)
@@ -1366,7 +1369,7 @@ class VideoVAE38_(VideoVAE_):
attn_scales, self.temperal_upsample, dropout)


def encode(self, x, pbar=True):
def encode(self, x, pbar=True, sample=False):
self.clear_cache()
x = patchify(x, patch_size=2)
t = x.shape[2]