mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-26 23:41:35 +03:00

Fix prompt splitting

Author: kijai
Date: 2025-12-07 18:34:21 +02:00
Commit: 2369cdbbe9 (parent 123c9ca312)


@@ -519,21 +519,13 @@ class WanSelfAttention(nn.Module):
        x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0], ref_target_masks=ref_target_masks)
        return x, x_ref_attn_map

    def forward_split(self, q, k, v, seq_lens, grid_sizes, seq_chunks):
        r"""
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        # Split by frames if multiple prompts are provided
        frames, height, width = grid_sizes[0]
        tokens_per_frame = height * width

        seq_chunks_tensor = torch.tensor(seq_chunks, device=q.device, dtype=frames.dtype)
        actual_chunks = torch.minimum(seq_chunks_tensor, frames)
        base_frames_per_chunk = frames // actual_chunks
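For reference, the chunk arithmetic above caps the number of chunks at the available frame count and then splits frames as evenly as possible. Below is a minimal standalone sketch of that distribution using plain Python ints instead of tensors; the helper name frame_chunk_sizes and the handling of leftover frames are assumptions, since the remainder logic falls outside this hunk.

def frame_chunk_sizes(frames: int, seq_chunks: int) -> list[int]:
    # Never create more chunks than there are frames (mirrors torch.minimum above).
    actual_chunks = min(seq_chunks, frames)
    base = frames // actual_chunks        # base_frames_per_chunk in the module code
    remainder = frames % actual_chunks    # assumed: leftover frames go to the first chunks
    return [base + (1 if i < remainder else 0) for i in range(actual_chunks)]

# Example: 81 latent frames split across 4 prompt chunks -> [21, 20, 20, 20]
assert frame_chunk_sizes(81, 4) == [21, 20, 20, 20]
assert sum(frame_chunk_sizes(81, 4)) == 81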
@@ -561,7 +553,7 @@ class WanSelfAttention(nn.Module):
        # output
        return self.o(x.flatten(2))

    def normalized_attention_guidance(self, b, n, d, q, context, nag_context=None, nag_params={}):
        # NAG text attention
        context_positive = context
@@ -1171,7 +1163,7 @@ class WanAttentionBlock(nn.Module):
        # ReCamMaster
        if camera_embed is not None:
            y = self.projector(y)
        # Stand-in
        if x_ip is not None:
@@ -1283,59 +1275,59 @@ class WanAttentionBlock(nn.Module):
            y_ip = self.ffn(torch.addcmul(shift_mlp_ip, self.norm2(x_ip), 1 + scale_mlp_ip))
            x_ip = x_ip.addcmul(y_ip, gate_mlp_ip)

        return x, x_ip, lynx_ref_feature, x_ovi

    @torch.compiler.disable()
    def split_cross_attn_ffn(self, x, context, shift_mlp, scale_mlp, gate_mlp, clip_embed=None, grid_sizes=None):
        # Get number of prompts
        num_prompts = context.shape[0]
        num_clip_embeds = 0 if clip_embed is None else clip_embed.shape[0]
        num_segments = max(num_prompts, num_clip_embeds)

        # Extract spatial dimensions
        frames, height, width = grid_sizes[0]  # Assuming batch size 1
        tokens_per_frame = height * width

        # Distribute frames across prompts
        frames_per_segment = max(1, frames // num_segments)

        # Process each prompt segment
        x_combined = torch.zeros_like(x)
        for i in range(num_segments):
            # Calculate frame boundaries for this segment
            start_frame = i * frames_per_segment
            end_frame = min((i+1) * frames_per_segment, frames) if i < num_segments-1 else frames

            # Convert frame indices to token indices
            start_idx = start_frame * tokens_per_frame
            end_idx = end_frame * tokens_per_frame
            segment_indices = torch.arange(start_idx, end_idx, device=x.device, dtype=torch.long)

            # Get prompt segment (cycle through available prompts if needed)
            prompt_idx = i % num_prompts
            segment_context = context[prompt_idx:prompt_idx+1]

            # Handle clip_embed for this segment (cycle through available embeddings)
            segment_clip_embed = None
            if clip_embed is not None:
                clip_idx = i % num_clip_embeds
                segment_clip_embed = clip_embed[clip_idx:clip_idx+1]

            # Get tensor segment
            x_segment = x[:, segment_indices, :].to(self.norm3.weight.dtype)

            # Process segment with its prompt and clip embedding
            processed_segment = self.cross_attn(self.norm3(x_segment), segment_context, clip_embed=segment_clip_embed)
            processed_segment = processed_segment.to(x.dtype)

            # Add to combined result
            x_combined[:, segment_indices, :] = processed_segment

        # Continue with FFN
        x = x + x_combined
        y = self.ffn_chunked(x, shift_mlp, scale_mlp)
        x = x.addcmul(y, gate_mlp)
        return x

        mod_x = torch.addcmul(shift_mlp, self.norm2(x.to(shift_mlp.dtype)), 1 + scale_mlp)
        y = self.ffn_chunked(mod_x, num_chunks=1)
        return x.addcmul(y, gate_mlp)
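To follow the segment bookkeeping above, here is a minimal standalone sketch of the frame-to-prompt mapping, mirroring the boundary logic in split_cross_attn_ffn; the function name prompt_segments and the plain-int inputs are illustrative, not part of the module.

def prompt_segments(frames: int, tokens_per_frame: int, num_prompts: int, num_clip_embeds: int = 0):
    """Yield (prompt_idx, clip_idx, token_start, token_end) for each segment."""
    num_segments = max(num_prompts, num_clip_embeds)
    frames_per_segment = max(1, frames // num_segments)
    for i in range(num_segments):
        start_frame = i * frames_per_segment
        # The last segment absorbs any leftover frames.
        end_frame = min((i + 1) * frames_per_segment, frames) if i < num_segments - 1 else frames
        clip_idx = i % num_clip_embeds if num_clip_embeds else None
        yield (i % num_prompts, clip_idx,
               start_frame * tokens_per_frame, end_frame * tokens_per_frame)

# Example: 10 latent frames, 30x52 tokens per frame, 3 prompts
# -> segments of 3, 3, and 4 frames, each attending to its own prompt.
for seg in prompt_segments(frames=10, tokens_per_frame=30 * 52, num_prompts=3):
    print(seg)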
class VaceWanAttentionBlock(WanAttentionBlock):
    def __init__(