from einops import rearrange, repeat
import torch
import torch.nn as nn

from ..wanvideo.modules.attention import attention


def timestep_transform(
    t,
    shift=5.0,
    num_timesteps=1000,
):
    """Apply the flow-matching timestep shift new_t = shift * t / (1 + (shift - 1) * t),
    with t first normalized to [0, 1] and the result mapped back to [0, num_timesteps]."""
    t = t / num_timesteps
    # shift the timestep based on ratio
    new_t = shift * t / (1 + (shift - 1) * t)
    new_t = new_t * num_timesteps
    return new_t
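
# Illustrative check (comment only, assumed values): with shift=5.0 and
# num_timesteps=1000, t=500 maps to 5*0.5 / (1 + 4*0.5) * 1000 ≈ 833.3, so
# mid-range timesteps are pushed toward the high-noise end of the schedule.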


def add_noise(
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """Linearly interpolate between the clean samples and noise.

    Compatible with diffusers' add_noise(): timestep 0 keeps the original
    samples, timestep 1000 yields pure noise.
    """
    timesteps = timesteps.float() / 1000
    timesteps = timesteps.view(timesteps.shape + (1,) * (len(noise.shape) - 1))

    return (1 - timesteps) * original_samples + timesteps * noise
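
# Illustrative endpoints (comment only): timesteps=0 returns original_samples
# unchanged and timesteps=1000 returns pure noise; intermediate values blend
# the two linearly.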


def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
    """Affine-map `column` from source_range to target_range; epsilon avoids
    division by zero when the source range collapses."""
    source_min, source_max = source_range
    new_min, new_max = target_range

    normalized = (column - source_min) / (source_max - source_min + epsilon)
    scaled = normalized * (new_max - new_min) + new_min
    return scaled
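
# Illustrative check (comment only, assumed values): a column spanning
# source_range=(0.0, 1.0) mapped to target_range=(0, 4) sends 0.5 to ~2.0;
# epsilon only guards against a zero-width source range.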


def rotate_half(x):
    """Rotate consecutive channel pairs by 90 degrees, the standard RoPE helper."""
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")
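
# Illustrative check (comment only): a last dim of [x0, x1, x2, x3] becomes
# [-x1, x0, -x3, x2], i.e. each consecutive channel pair is rotated by 90 degrees.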


def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=4):
    """For each reference mask, measure how much attention every visual query token
    pays to that mask's reference tokens, using a chunked softmax to bound memory."""
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q.transpose(1, 2) * scale

    B, H, x_seqlens, K = visual_q.shape

    x_ref_attn_maps = []
    for class_idx, ref_target_mask in enumerate(ref_target_masks):
        ref_target_mask = ref_target_mask.view(1, 1, 1, -1)

        x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
        chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)

        for i in range(0, x_seqlens, chunk_size):
            end_i = min(i + chunk_size, x_seqlens)

            attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1)  # B, H, chunk, ref_seqlens

            # Apply a numerically stable softmax over the reference tokens
            attn_max = attn_chunk.max(dim=-1, keepdim=True).values
            attn_chunk = (attn_chunk - attn_max).exp()
            attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
            attn_chunk = attn_chunk / (attn_sum + 1e-8)

            # Apply mask, sum, and normalise by the mask size
            masked_attn = attn_chunk * ref_target_mask
            x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)

            del attn_chunk, masked_attn

        # Average across heads
        x_ref_attnmap = x_ref_attnmap.mean(dim=1)  # B, x_seqlens
        x_ref_attn_maps.append(x_ref_attnmap)

    del visual_q, ref_k

    return torch.cat(x_ref_attn_maps, dim=0)
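
# Shape sketch (comment only, for batch size 1): visual_q (1, M, H, K) and
# ref_k (1, R, H, K) with C masks of length R yield one head-averaged map of
# shape (1, M) per mask, concatenated along dim 0 into a (C, M) tensor.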


def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
    """Args:
        visual_q (torch.Tensor): B M H K video-token queries
        ref_k (torch.Tensor): B M H K keys; only the first N_h * N_w reference tokens are used
        shape (tuple): (N_t, N_h, N_w)
        ref_target_masks (torch.Tensor): [class_num, N_h * N_w]
    Returns:
        [class_num, M] per-token relevance maps, averaged over `split_num` head groups.
    """

    N_t, N_h, N_w = shape

    x_seqlens = N_h * N_w
    ref_k = ref_k[:, :x_seqlens]
    _, seq_lens, heads, _ = visual_q.shape
    class_num, _ = ref_target_masks.shape
    x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q.device).to(visual_q.dtype)

    split_chunk = heads // split_num

    for i in range(split_num):
        x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
            visual_q[:, :, i * split_chunk:(i + 1) * split_chunk, :],
            ref_k[:, :, i * split_chunk:(i + 1) * split_chunk, :],
            ref_target_masks,
        )
        x_ref_attn_maps += x_ref_attn_maps_perhead

    return x_ref_attn_maps / split_num
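
# Usage sketch (comment only; the 21x30x52 latent grid and two speakers are assumed):
# visual_q is (1, 21*30*52, heads, head_dim), ref_k keeps only its first 30*52
# reference tokens, ref_target_masks is (2, 30*52), and the result averaged over
# the head splits is (2, 21*30*52): one relevance value per video token per speaker.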


class RotaryPositionalEmbedding1D(nn.Module):

    def __init__(self,
                 head_dim,
                 ):
        super().__init__()
        self.head_dim = head_dim
        self.base = 10000

    def precompute_freqs_cis_1d(self, pos_indices):
        # Standard RoPE frequency ladder; each frequency is repeated so that it
        # covers a consecutive (cos, sin) channel pair, matching rotate_half
        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
        freqs = freqs.to(pos_indices.device)
        freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def forward(self, x, pos_indices):
        """1D RoPE.

        Args:
            x (torch.Tensor): [B, head, seq, head_dim]
            pos_indices (torch.Tensor): [seq,]
        Returns:
            Tensor with the same shape and dtype as the input.
        """
        freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
        in_dtype = x.dtype
        x = x.float()

        freqs_cis = freqs_cis.float().to(x.device)
        cos = rearrange(freqs_cis.cos(), 'n d -> 1 1 n d')
        sin = rearrange(freqs_cis.sin(), 'n d -> 1 1 n d')

        # In-place rotation to save memory
        x_rotated = rotate_half(x)
        x.mul_(cos).add_(x_rotated * sin)

        return x.to(in_dtype)
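
# Usage sketch (comment only; sizes are assumed):
#   rope = RotaryPositionalEmbedding1D(head_dim=128)
#   q = torch.randn(1, 12, 1560, 128)        # [B, head, seq, head_dim]
#   pos = torch.linspace(0, 24, 1560)        # fractional positions are fine
#   q = rope(q, pos)                         # same shape, rotated in float32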


class AudioProjModel(nn.Module):
    def __init__(
        self,
        seq_len=5,
        seq_len_vf=12,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
        norm_output_audio=False,
    ):
        super().__init__()

        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        self.input_dim = seq_len * blocks * channels
        self.input_dim_vf = seq_len_vf * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        # define multiple linear layers
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
        self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()

    def forward(self, audio_embeds, audio_embeds_vf):
        video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
        B, _, _, S, C = audio_embeds.shape

        # process audio of first frame
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
        batch_size, window_size, blocks, channels = audio_embeds.shape
        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)

        # process audio of latter frames
        audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
        batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
        audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)

        # first projection
        audio_embeds = torch.relu(self.proj1(audio_embeds))
        audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
        audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
        audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
        audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
        batch_size_c, N_t, C_a = audio_embeds_c.shape
        audio_embeds_c = audio_embeds_c.view(batch_size_c * N_t, C_a)

        # second projection
        audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))

        context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c * N_t, self.context_tokens, self.output_dim)

        # normalization and reshape (nn.Identity has no weight, so only cast for LayerNorm)
        if isinstance(self.norm, nn.LayerNorm):
            context_tokens = self.norm(context_tokens.to(self.norm.weight.dtype)).to(context_tokens.dtype)
        else:
            context_tokens = self.norm(context_tokens)
        context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)

        return context_tokens
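
# Usage sketch (comment only; batch size, frame counts and the default dims are
# assumed for illustration):
#   proj = AudioProjModel(norm_output_audio=True)
#   first = torch.randn(1, 1, 5, 12, 768)    # (bz, frames, window=seq_len, blocks, channels)
#   later = torch.randn(1, 20, 12, 12, 768)  # (bz, frames, window=seq_len_vf, blocks, channels)
#   ctx = proj(first, later)                 # -> (1, 21, 32, 768) audio context tokens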


#@torch.compiler.disable()
class SingleStreamAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        attention_mode: str = 'sdpa',
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.encoder_hidden_states_dim = encoder_hidden_states_dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.attention_mode = attention_mode

        self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias)

    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
        N_t, N_h, N_w = shape

        expected_tokens = N_t * N_h * N_w
        actual_tokens = x.shape[1]
        x_extra = None

        # If x carries more tokens than the N_t frames account for, split off the
        # trailing frame; it is zero-filled in the output instead of attending.
        if actual_tokens != expected_tokens:
            x_extra = x[:, -N_h * N_w:, :]
            x = x[:, :-N_h * N_w, :]
            N_t = N_t - 1

        B = x.shape[0]
        S = N_h * N_w
        x = x.view(B * N_t, S, self.dim)

        # get q for hidden_state
        q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)

        # get kv from encoder_hidden_states  # shape: (B * N_t, N, num_heads, head_dim)
        kv = self.kv_linear(encoder_hidden_states)
        encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)

        x = attention(q, encoder_k, encoder_v, attention_mode=self.attention_mode)

        # linear transform
        x = self.proj(x.reshape(B * N_t, S, self.dim))
        x = x.view(B, N_t * S, self.dim)

        if x_extra is not None:
            x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)

        return x
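
# Usage sketch (comment only; sizes and the attention() return layout are assumed):
#   attn = SingleStreamAttention(dim=1536, encoder_hidden_states_dim=768,
#                                num_heads=12, qkv_bias=True)
#   x = torch.randn(1, 21 * 30 * 52, 1536)    # video tokens for N_t=21 frames
#   audio = torch.randn(21, 32, 768)          # per-frame audio context tokens
#   out = attn(x, audio, shape=(21, 30, 52))  # -> (1, 21 * 30 * 52, 1536)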


class SingleStreamMultiAttention(SingleStreamAttention):
    """Multi-speaker rotary-position cross-attention.

    This implementation generalises the original 2-speaker logic to an arbitrary
    number of voices. Each speaker is allocated a contiguous *class_interval*
    segment inside a shared *class_range* rotary bucket. The centre of each
    segment is applied to that speaker's KV tokens, while queries are modulated
    per-token according to which speaker dominates that spatial token.
    """

    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        class_range: int = 24,
        class_interval: int = 4,
        attention_mode: str = 'sdpa',
    ) -> None:
        super().__init__(
            dim=dim,
            encoder_hidden_states_dim=encoder_hidden_states_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attention_mode=attention_mode,
        )

        # Rotary-embedding layout parameters
        self.class_interval = class_interval
        self.class_range = class_range
        self.max_humans = self.class_range // self.class_interval

        # Constant bucket used for background tokens
        self.rope_bak = int(self.class_range // 2)

        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

        self.attention_mode = attention_mode

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        shape=None,
        x_ref_attn_map=None,
        human_num=None,
    ) -> torch.Tensor:
        encoder_hidden_states = encoder_hidden_states.squeeze(0)

        # Single-speaker fall-through
        if human_num is None or human_num <= 1:
            return super().forward(x, encoder_hidden_states, shape)

        N_t, N_h, N_w = shape

        # Split off a trailing extra frame, mirroring SingleStreamAttention
        x_extra = None
        if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
            x_extra = x[:, -N_h * N_w:, :]
            x = x[:, :-N_h * N_w, :]
            N_t = N_t - 1
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # Query projection
        B, N, C = x.shape
        q = self.q_linear(x)
        q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        if human_num == 2:
            # Use `class_range` logic for exactly 2 speakers
            rope_h1 = (0, self.class_interval)
            rope_h2 = (self.class_range - self.class_interval, self.class_range)
            rope_bak = int(self.class_range // 2)

            # Normalize and scale attention maps for each speaker
            max_values = x_ref_attn_map.max(1).values[:, None, None]
            min_values = x_ref_attn_map.min(1).values[:, None, None]
            max_min_values = torch.cat([max_values, min_values], dim=2)

            human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
            human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()

            human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
            human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
            back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)

            # Token-wise speaker dominance
            max_indices = x_ref_attn_map.argmax(dim=0)
            normalized_map = torch.stack([human1, human2, back], dim=1)
            normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]
        else:
            # General case for more than 2 speakers
            rope_ranges = [
                (i * self.class_interval, (i + 1) * self.class_interval)
                for i in range(human_num)
            ]

            # Normalize each speaker's attention map into its own bucket
            human_norm_list = []
            for idx in range(human_num):
                attn_map = x_ref_attn_map[idx]
                att_min, att_max = attn_map.min(), attn_map.max()
                human_norm = normalize_and_scale(
                    attn_map, (att_min, att_max), rope_ranges[idx]
                )
                human_norm_list.append(human_norm)

            # Background constant bucket
            back = torch.full(
                (x_ref_attn_map.size(1),),
                self.rope_bak,
                dtype=x_ref_attn_map.dtype,
                device=x_ref_attn_map.device,
            )

            # Token-wise speaker dominance
            max_indices = x_ref_attn_map.argmax(dim=0)
            normalized_map = torch.stack(human_norm_list + [back], dim=1)
            normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]

        # Apply rotary to Q
        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        q = self.rope_1d(q, normalized_pos)
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        # Keys / Values
        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        encoder_k, encoder_v = encoder_kv.unbind(0)

        # Rotary for keys: assign the centre of each speaker bucket to its context tokens
        if human_num == 2:
            per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
            per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
            per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2
            encoder_pos = torch.cat([per_frame] * N_t, dim=0)
        else:
            # Assumes the N_a context tokens are evenly divisible between speakers
            tokens_per_human = N_a // human_num
            encoder_pos_list = []
            for i in range(human_num):
                start, end = rope_ranges[i]
                centre = (start + end) / 2
                encoder_pos_list.append(
                    torch.full(
                        (tokens_per_human,), centre, dtype=encoder_k.dtype, device=encoder_k.device
                    )
                )
            encoder_pos = torch.cat(encoder_pos_list * N_t, dim=0)

        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        # Final attention
        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
        x = attention(
            q, encoder_k, encoder_v, attention_mode=self.attention_mode
        )

        # Linear projection
        x = x.reshape(B, N, C)
        x = self.proj(x)

        # Restore original layout
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
        if x_extra is not None:
            x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)

        return x
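
# Usage sketch (comment only; all names, sizes and the upstream call that produces
# x_ref_attn_map are assumptions for illustration):
#   attn = SingleStreamMultiAttention(dim=1536, encoder_hidden_states_dim=768,
#                                     num_heads=12, qkv_bias=True)
#   ref_map = get_attn_map_with_target(q_vid, k_ref, (21, 30, 52),
#                                      ref_target_masks=masks)   # (2, 21*30*52)
#   out = attn(x, audio, shape=(21, 30, 52),
#              x_ref_attn_map=ref_map, human_num=2)              # (1, 21*30*52, 1536)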