import torch
import torch.nn as nn
from einops import rearrange
from ..wanvideo.modules.attention import attention


def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
    return (x * (1 + scale) + shift)


def sinusoidal_embedding_1d(dim, position):
    sinusoid = torch.outer(position.type(torch.float64), torch.pow(
        10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    return x.to(position.dtype)
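
# Shape sketch for the embedding above, assuming `position` is a 1-D tensor of
# L timesteps: the outer product yields (L, dim // 2) phases, and the
# [cos | sin] concatenation returns an (L, dim) tensor in `position`'s dtype.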


def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
    # 3d rope precompute
    f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
    h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
    w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
    return f_freqs_cis, h_freqs_cis, w_freqs_cis


def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
    # 1d rope precompute
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
                             [: (dim // 2)].double() / dim))
    freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis
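
# The 1-D table above has shape (end, dim // 2) as a complex tensor of phase
# rotations; the 3-D variant splits the per-head channels across the
# frame / height / width axes so the three partial tables together cover the
# full head dimension.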


def rope_apply(x, freqs, num_heads):
    x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
    x_out = torch.view_as_complex(x.to(torch.float64).reshape(
        x.shape[0], x.shape[1], x.shape[2], -1, 2))
    x_out = torch.view_as_real(x_out * freqs).flatten(2)
    return x_out.to(x.dtype)
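
# Rotation sketch, assuming `x` is (batch, seq, num_heads * head_dim) and
# `freqs` is a complex tensor broadcastable against
# (batch, seq, num_heads, head_dim // 2): adjacent channel pairs are viewed as
# complex numbers, rotated by `freqs`, and flattened back to the input layout
# and dtype.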


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        dtype = x.dtype
        return self.norm(x.float()).to(dtype) * self.weight


class AttentionModule(nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

    def forward(self, q, k, v):
        b, n, d = q.size(0), self.num_heads, self.head_dim
        x = attention(
            q.view(b, -1, n, d),
            k.view(b, -1, n, d),
            v.view(b, -1, n, d)
        )
        return x.flatten(2)
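
# AttentionModule is a thin wrapper around the imported `attention` call: it
# reshapes flat (batch, seq, num_heads * head_dim) q/k/v into per-head layout
# (batch, seq, num_heads, head_dim) and flattens the head dimensions back
# together afterwards, so callers never deal with head splitting themselves.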


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads, self.head_dim)

    def forward(self, x, freqs):
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)
        q = rope_apply(q, freqs, self.num_heads)
        k = rope_apply(k, freqs, self.num_heads)
        x = self.attn(q, k, v)
        return self.o(x)
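
# SelfAttention follows the QK-norm pattern: RMSNorm on the query/key
# projections, rotary position embeddings applied to q and k only (not v), and
# a final output projection `o` back to the model width.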


class CrossAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, clip_fea: torch.Tensor = None):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps)
        self.norm_k = RMSNorm(dim, eps=eps)

        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        self.norm_k_img = RMSNorm(dim, eps=eps)

        self.attn = AttentionModule(self.num_heads, self.head_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor, clip_fea: torch.Tensor = None):
        ctx = y
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(ctx))
        v = self.v(ctx)
        x = self.attn(q, k, v)
        if clip_fea is not None:
            k_img = self.norm_k_img(self.k_img(clip_fea))
            v_img = self.v_img(clip_fea)
            y = self.attn(q, k_img, v_img)
            x = x + y
        return self.o(x)
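
# CrossAttention attends the latent tokens to the text context `y`; when CLIP
# image features are passed as `clip_fea`, a second attention pass over the
# image keys/values is added to the text result before the output projection.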


class GateModule(nn.Module):
    def __init__(self,):
        super().__init__()

    def forward(self, x, gate, residual):
        return x + gate * residual


class DiTBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim

        self.self_attn = SelfAttention(dim, num_heads, eps)
        self.cross_attn = CrossAttention(dim, num_heads, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
            approximate='tanh'), nn.Linear(ffn_dim, dim))
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        self.gate = GateModule()

    def forward(self, x, context, t_mod, freqs, clip_fea=None):
        has_seq = len(t_mod.shape) == 4
        chunk_dim = 2 if has_seq else 1
        # msa: multi-head self-attention; mlp: multi-layer perceptron
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=chunk_dim)
        if has_seq:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                shift_msa.squeeze(2), scale_msa.squeeze(2), gate_msa.squeeze(2),
                shift_mlp.squeeze(2), scale_mlp.squeeze(2), gate_mlp.squeeze(2),
            )
        input_x = modulate(self.norm1(x), shift_msa, scale_msa)
        x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
        x = x + self.cross_attn(self.norm3(x), context, clip_fea=clip_fea)
        input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = self.gate(x, gate_mlp, self.ffn(input_x))
        return x
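
# The block uses DiT-style adaLN conditioning: the learned `modulation` table
# plus the timestep embedding `t_mod` yields six vectors (shift/scale/gate for
# the self-attention path and for the FFN). A 4-D `t_mod` carries an extra
# per-token sequence axis, which the squeeze branch collapses after chunking.
# Cross-attention to the text (and optional CLIP image) context runs between
# the two gated sub-layers.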


class WanModelDualControl(torch.nn.Module):
    def __init__(self, dim: int, ffn_dim: int, eps: float, num_heads: int, control_layers=12):
        super().__init__()
        self.control_layers = control_layers

        self.control_blocks_dense = nn.ModuleList([
            DiTBlock(dim//2, num_heads//2, ffn_dim//2, eps)
            for _ in range(self.control_layers)
        ])
        self.control_blocks_sparse = nn.ModuleList([
            DiTBlock(dim//2, num_heads//2, ffn_dim//2, eps)
            for _ in range(self.control_layers)
        ])

        self.control_initial_combine_linear_dense = torch.nn.Linear(dim, dim//2)
        self.control_initial_combine_linear_sparse = torch.nn.Linear(dim, dim//2)
        self.control_text_linear = torch.nn.Linear(dim, dim//2)
        self.control_t_mod = torch.nn.Linear(dim, dim//2)
        self.control_combine_linears = torch.nn.ModuleList([torch.nn.Linear(dim//2, dim) for _ in range(self.control_layers)])

        head_dim = dim // num_heads
        self.freqs = precompute_freqs_cis_3d(head_dim)
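
# WanModelDualControl holds two half-width DiT stacks, one for a dense and one
# for a sparse control signal, plus linears that project dim-sized features
# down to dim//2 for the control blocks and, per layer, back up to dim,
# presumably so the control features can be injected into the base model's
# blocks. For even dim and num_heads, (dim // 2) // (num_heads // 2) equals
# dim // num_heads, so the rotary tables precomputed at the base head_dim also
# fit the half-width control blocks.
#
# Rough construction sketch (hypothetical, Wan 14B-like sizes assumed):
#   control = WanModelDualControl(dim=5120, ffn_dim=13824, eps=1e-6, num_heads=40)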