ComfyUI-WanVideoWrapper (mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git)

Fix prompt splitting
@@ -519,21 +519,13 @@ class WanSelfAttention(nn.Module):
        x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0], ref_target_masks=ref_target_masks)

        return x, x_ref_attn_map

    def forward_split(self, q, k, v, seq_lens, grid_sizes, seq_chunks):
        r"""
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        # Split by frames if multiple prompts are provided
        frames, height, width = grid_sizes[0]
        tokens_per_frame = height * width

        seq_chunks_tensor = torch.tensor(seq_chunks, device=q.device, dtype=frames.dtype)
        actual_chunks = torch.minimum(seq_chunks_tensor, frames)
        base_frames_per_chunk = frames // actual_chunks
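For orientation, a minimal standalone sketch of the chunking arithmetic in this hunk, using made-up grid sizes; the remainder handling and the token-range bookkeeping are assumptions about the part of the method the excerpt cuts off, not the committed code:

import torch

# Hypothetical latent grid: 21 frames of 30x52 tokens, split across 2 prompt chunks
frames = torch.tensor(21)
height, width = 30, 52
tokens_per_frame = height * width
seq_chunks = 2

seq_chunks_tensor = torch.tensor(seq_chunks, dtype=frames.dtype)
actual_chunks = torch.minimum(seq_chunks_tensor, frames)  # never more chunks than frames
base_frames_per_chunk = frames // actual_chunks           # floor division; leftovers handled below

start = 0
for i in range(int(actual_chunks)):
    # assumption: the last chunk absorbs any leftover frames
    extra = int(frames % actual_chunks) if i == int(actual_chunks) - 1 else 0
    n = int(base_frames_per_chunk) + extra
    print(f"chunk {i}: frames [{start}, {start + n}) -> tokens [{start * tokens_per_frame}, {(start + n) * tokens_per_frame})")
    start += n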
@@ -561,7 +553,7 @@ class WanSelfAttention(nn.Module):
        # output
        return self.o(x.flatten(2))

    def normalized_attention_guidance(self, b, n, d, q, context, nag_context=None, nag_params={}):
        # NAG text attention
        context_positive = context
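The excerpt shows only the first line of normalized_attention_guidance. As context, a rough sketch of the standard NAG combination step (extrapolate the positive and negative attention outputs, clamp the L1-norm ratio at tau, blend back with alpha); the function name and the nag_scale/nag_tau/nag_alpha defaults are illustrative assumptions, not this method's exact code:

import torch

def nag_blend(x_positive, x_negative, nag_scale=3.0, nag_tau=2.5, nag_alpha=0.25):
    # Extrapolate the positive cross-attention output away from the negative one
    x_guidance = x_positive * nag_scale - x_negative * (nag_scale - 1)
    # Clamp the per-token L1-norm ratio at nag_tau so the extrapolation cannot blow up
    norm_pos = torch.norm(x_positive, p=1, dim=-1, keepdim=True)
    norm_gui = torch.norm(x_guidance, p=1, dim=-1, keepdim=True)
    ratio = norm_gui / norm_pos
    x_guidance = x_guidance * ratio.clamp(max=nag_tau) / ratio
    # Blend back toward the positive branch
    return x_guidance * nag_alpha + x_positive * (1 - nag_alpha)

# e.g. two attention outputs of shape [B, L, C]
out = nag_blend(torch.randn(1, 16, 64), torch.randn(1, 16, 64))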
@@ -1171,7 +1163,7 @@ class WanAttentionBlock(nn.Module):
        # ReCamMaster
        if camera_embed is not None:
            y = self.projector(y)

        # Stand-in
        if x_ip is not None:
@@ -1283,59 +1275,59 @@ class WanAttentionBlock(nn.Module):
            y_ip = self.ffn(torch.addcmul(shift_mlp_ip, self.norm2(x_ip), 1 + scale_mlp_ip))
            x_ip = x_ip.addcmul(y_ip, gate_mlp_ip)
        return x, x_ip, lynx_ref_feature, x_ovi

    @torch.compiler.disable()
    def split_cross_attn_ffn(self, x, context, shift_mlp, scale_mlp, gate_mlp, clip_embed=None, grid_sizes=None):
        # Get number of prompts
        num_prompts = context.shape[0]
        num_clip_embeds = 0 if clip_embed is None else clip_embed.shape[0]
        num_segments = max(num_prompts, num_clip_embeds)

        # Extract spatial dimensions
        frames, height, width = grid_sizes[0]  # Assuming batch size 1
        tokens_per_frame = height * width

        # Distribute frames across prompts
        frames_per_segment = max(1, frames // num_segments)

        # Process each prompt segment
        x_combined = torch.zeros_like(x)

        for i in range(num_segments):
            # Calculate frame boundaries for this segment
            start_frame = i * frames_per_segment
            end_frame = min((i+1) * frames_per_segment, frames) if i < num_segments-1 else frames

            # Convert frame indices to token indices
            start_idx = start_frame * tokens_per_frame
            end_idx = end_frame * tokens_per_frame
            segment_indices = torch.arange(start_idx, end_idx, device=x.device, dtype=torch.long)

            # Get prompt segment (cycle through available prompts if needed)
            prompt_idx = i % num_prompts
            segment_context = context[prompt_idx:prompt_idx+1]

            # Handle clip_embed for this segment (cycle through available embeddings)
            segment_clip_embed = None
            if clip_embed is not None:
                clip_idx = i % num_clip_embeds
                segment_clip_embed = clip_embed[clip_idx:clip_idx+1]

            # Get tensor segment
            x_segment = x[:, segment_indices, :].to(self.norm3.weight.dtype)

            # Process segment with its prompt and clip embedding
            processed_segment = self.cross_attn(self.norm3(x_segment), segment_context, clip_embed=segment_clip_embed)
            processed_segment = processed_segment.to(x.dtype)

            # Add to combined result
            x_combined[:, segment_indices, :] = processed_segment

        # Continue with FFN
        x = x + x_combined
        y = self.ffn_chunked(x, shift_mlp, scale_mlp)
        x = x.addcmul(y, gate_mlp)
        return x
        mod_x = torch.addcmul(shift_mlp, self.norm2(x.to(shift_mlp.dtype)), 1 + scale_mlp)
        y = self.ffn_chunked(mod_x, num_chunks=1)
        return x.addcmul(y, gate_mlp)
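To make the prompt splitting above concrete, a small standalone sketch of the same index bookkeeping: frames are divided into contiguous segments, each segment's tokens are addressed frame-major, and prompts (and clip embeddings) are cycled when segments outnumber them. The sizes are made up, and the print stands in for the per-segment cross-attention call:

num_prompts = 3                      # one context row per prompt
num_clip_embeds = 0
num_segments = max(num_prompts, num_clip_embeds)

frames, height, width = 21, 30, 52   # latent grid, batch size 1
tokens_per_frame = height * width
frames_per_segment = max(1, frames // num_segments)

for i in range(num_segments):
    start_frame = i * frames_per_segment
    # the final segment absorbs any leftover frames
    end_frame = min((i + 1) * frames_per_segment, frames) if i < num_segments - 1 else frames
    start_idx = start_frame * tokens_per_frame
    end_idx = end_frame * tokens_per_frame
    prompt_idx = i % num_prompts     # cycle prompts if there are more segments than prompts
    print(f"segment {i}: frames [{start_frame}, {end_frame}) -> tokens [{start_idx}, {end_idx}) with prompt {prompt_idx}")

The frame-major token layout (all tokens of frame 0, then frame 1, and so on) is what lets a contiguous torch.arange slice select exactly the frames that belong to one prompt.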
class VaceWanAttentionBlock(WanAttentionBlock):
    def __init__(