import torch
import torch.nn as nn
from einops import rearrange

from ..wanvideo.modules.attention import attention


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    return (x * (1 + scale) + shift)


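# Illustrative note (not part of the upstream file): `modulate` applies an
# AdaLN-style affine modulation, x * (1 + scale) + shift, with `shift` and
# `scale` broadcast against x. A minimal sketch, assuming a hidden size of 16:
#
#   x = torch.randn(1, 4, 16)                           # (batch, seq, dim)
#   shift, scale = torch.zeros(1, 1, 16), torch.zeros(1, 1, 16)
#   assert torch.equal(modulate(x, shift, scale), x)    # identity when both are zero
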
def sinusoidal_embedding_1d(dim, position):
    sinusoid = torch.outer(position.type(torch.float64), torch.pow(
        10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x.to(position.dtype)


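# Illustrative note (not part of the upstream file): for a 1-D position tensor
# of length L, `sinusoidal_embedding_1d(dim, position)` returns an (L, dim)
# table of [cos | sin] features. Passing float positions avoids the final cast
# back to an integer dtype. A minimal sketch:
#
#   t = torch.arange(8, dtype=torch.float32)
#   emb = sinusoidal_embedding_1d(256, t)                # emb.shape == (8, 256)
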
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    # 3d rope precompute
    f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
    h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
    w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
    return f_freqs_cis, h_freqs_cis, w_freqs_cis


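# Illustrative note (not part of the upstream file): the per-head rotary
# dimension is split across the frame/height/width axes, with the remainder
# (dim - 2 * (dim // 3)) assigned to the frame axis. For head_dim = 128:
#
#   f, h, w = precompute_freqs_cis_3d(128)
#   # f.shape == (1024, 22); h.shape == w.shape == (1024, 21)  (complex tensors)
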
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    # 1d rope precompute
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
                             [: (dim // 2)].double() / dim))
    freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis


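# Illustrative note (not part of the upstream file): `precompute_freqs_cis`
# returns the standard 1-D RoPE table as a complex tensor of shape
# (end, dim // 2), i.e. one unit-magnitude rotation exp(i * pos * theta**(-2k/dim))
# per (position, frequency) pair. For example, precompute_freqs_cis(64, end=16)
# has shape (16, 32).
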
def rope_apply(x, freqs, num_heads):
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))
    x_out = torch.view_as_real(x_out * freqs).flatten(2)
    return x_out.to(x.dtype)


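# Illustrative note (not part of the upstream file): `rope_apply` expects x of
# shape (batch, seq, num_heads * head_dim) and a complex `freqs` tensor that
# broadcasts against the complex view (batch, seq, num_heads, head_dim // 2).
# A minimal sketch with a 1-D table (the model itself assembles `freqs` from
# the 3-D tables above):
#
#   num_heads, head_dim, seq = 8, 64, 16
#   x = torch.randn(1, seq, num_heads * head_dim)
#   freqs = precompute_freqs_cis(head_dim, end=seq)[None, :, None, :]  # (1, 16, 1, 32)
#   out = rope_apply(x, freqs, num_heads)                # out.shape == (1, 16, 512)
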
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        dtype = x.dtype
        return self.norm(x.float()).to(dtype) * self.weight


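# Illustrative note (not part of the upstream file): RMSNorm computes
# x / sqrt(mean(x**2, dim=-1) + eps) in float32, casts back to the input dtype,
# and rescales by a learned per-channel weight, e.g.:
#
#   norm = RMSNorm(64)
#   y = norm(torch.randn(2, 10, 64))                     # y.shape == (2, 10, 64)
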
class AttentionModule(nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(self, q, k, v):
        b, n, d = q.size(0), self.num_heads, self.head_dim
        x = attention(
            q.view(b, -1, n, d),
            k.view(b, -1, n, d),
            v.view(b, -1, n, d)
        )
        return x.flatten(2)


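# Illustrative note (not part of the upstream file): AttentionModule reshapes
# flat (batch, seq, num_heads * head_dim) projections into a per-head layout
# for the wrapper's `attention` kernel and flattens the heads back afterwards,
# so callers keep working with (batch, seq, dim) tensors throughout.
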
class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads, self.head_dim)

    def forward(self, x, freqs):
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)
        q = rope_apply(q, freqs, self.num_heads)
        k = rope_apply(k, freqs, self.num_heads)
        x = self.attn(q, k, v)
        return self.o(x)


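# Illustrative note (not part of the upstream file): a minimal shape sketch for
# SelfAttention, assuming dim = 512 and 8 heads (head_dim = 64):
#
#   attn = SelfAttention(dim=512, num_heads=8)
#   x = torch.randn(1, 16, 512)
#   freqs = precompute_freqs_cis(64, end=16)[None, :, None, :]
#   y = attn(x, freqs)                                   # y.shape == (1, 16, 512)
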
class CrossAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, clip_fea: torch.Tensor = None):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        self.norm_k_img = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads, self.head_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor, clip_fea: torch.Tensor = None):
        ctx = y
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)
        x = self.attn(q, k, v)
        if clip_fea is not None:
            k_img = self.norm_k_img(self.k_img(clip_fea))
            v_img = self.v_img(clip_fea)
            y = self.attn(q, k_img, v_img)
            x = x + y
        return self.o(x)


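# Illustrative note (not part of the upstream file): CrossAttention attends the
# latent tokens `x` over the text context `y`; when image CLIP features
# (`clip_fea`) are provided, a second attention pass over those features is run
# with the same queries and added to the text branch before the output
# projection.
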
class GateModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, gate, residual):
        return x + gate * residual


class DiTBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim

        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(dim, num_heads, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
            approximate='tanh'), nn.Linear(ffn_dim, dim))
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs, clip_fea=None):
        has_seq = len(t_mod.shape) == 4
        chunk_dim = 2 if has_seq else 1
        # msa: multi-head self-attention; mlp: multi-layer perceptron
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)
        if has_seq:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
                shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
            )
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
        x = x + self.cross_attn(self.norm3(x), context, clip_fea=clip_fea)
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        return x


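# Illustrative note (not part of the upstream file): the learned `modulation`
# parameter plus the timestep embedding `t_mod` is split into six chunks
# (shift/scale/gate for self-attention, shift/scale/gate for the FFN). `t_mod`
# may be (batch, 6, dim) or, with a per-token axis, (batch, seq, 6, dim); in
# the latter case each chunk is squeezed back to (batch, seq, dim) before the
# AdaLN modulation and gated residuals are applied.
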
class WanModelDualControl(torch.nn.Module):
    def __init__(self, dim: int, ffn_dim: int, eps: float, num_heads: int, control_layers=12):
        super().__init__()
        self.control_layers = control_layers
        self.control_blocks_dense = nn.ModuleList([
            DiTBlock(dim//2, num_heads//2, ffn_dim//2, eps)
            for _ in range(self.control_layers)
        ])

        self.control_blocks_sparse = nn.ModuleList([
            DiTBlock(dim//2, num_heads//2, ffn_dim//2, eps)
            for _ in range(self.control_layers)
        ])

        self.control_initial_combine_linear_dense = torch.nn.Linear(dim, dim//2)
        self.control_initial_combine_linear_sparse = torch.nn.Linear(dim, dim//2)

        self.control_text_linear = torch.nn.Linear(dim, dim//2)
        self.control_t_mod = torch.nn.Linear(dim, dim//2)

        self.control_combine_linears = torch.nn.ModuleList([torch.nn.Linear(dim//2, dim) for _ in range(self.control_layers)])
        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)
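

# Illustrative note (not part of the upstream file): a minimal construction
# sketch with hypothetical dimensions; the control branch runs its dense and
# sparse DiTBlocks at half the base width and projects back up to the base
# width through `control_combine_linears`:
#
#   model = WanModelDualControl(dim=1536, ffn_dim=8960, eps=1e-6, num_heads=12)
#   # 12 dense + 12 sparse control blocks of width 768 with 6 heads each;
#   # RoPE tables are precomputed for head_dim = 1536 // 12 = 128.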