import torch
import torch.nn as nn
import numpy as np
from diffusers.models import ModelMixin
from typing import Optional, Tuple, Union
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention

from einops import rearrange


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
    freqs_dtype=torch.float32,  # torch.float32, torch.float64 (flux)
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials for the given dimension 'dim' and
    position indices 'pos'. The 'theta' parameter scales the frequencies. When 'use_real' is False, the returned
    tensor contains complex values in complex64 data type.

    Args:
        dim (`int`): Dimension of the frequency tensor.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation.
        use_real (`bool`, *optional*):
            If True, return the real and imaginary parts separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for context extrapolation.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for NTK-Aware RoPE.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, the real and imaginary parts are each interleaved with themselves to reach
            `dim`. Otherwise, they are concatenated with themselves.
        freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
            The dtype of the frequency tensor.

    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = torch.arange(pos)
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)  # type: ignore # [S]

    theta = theta * ntk_factor
    freqs = (
        1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore # [S, D/2]
    is_npu = freqs.device.type == "npu"
    if is_npu:
        freqs = freqs.float()
    if use_real and repeat_interleave_real:
        # flux, hunyuan-dit, cogvideox
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float()  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        # stable audio, allegro
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float()  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float()  # [S, D]
        return freqs_cos, freqs_sin
    else:
        # lumina
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
        return freqs_cis
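

# Usage sketch (illustrative, not part of the original file): for dim=64 and 16 positions,
# the complex branch returns a [16, 32] complex64 tensor, while use_real=True returns a
# (cos, sin) pair of [16, 64] tensors.
#
#     freqs_cis = get_1d_rotary_pos_embed(64, 16)                 # [16, 32], complex64
#     cos, sin = get_1d_rotary_pos_embed(64, 16, use_real=True)   # [16, 64] each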


class WanRotaryPosEmbed(nn.Module):
    def __init__(
        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim

        freqs = []
        for dim in [t_dim, h_dim, w_dim]:
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        self.freqs = self.freqs.to(hidden_states.device)
        freqs = self.freqs.split_with_sizes(
            [
                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
                self.attention_head_dim // 6,
                self.attention_head_dim // 6,
            ],
            dim=1,
        )

        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
        return freqs
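

# Worked size example (illustrative, assuming attention_head_dim=128): __init__ builds
# complex frequency tables of per-position width 22 (t), 21 (h) and 21 (w), i.e. dim/2 per
# axis, concatenated to [max_seq_len, 64]; forward() splits them back with
# split_with_sizes([22, 21, 21]) and broadcasts each axis table over the
# (frames, height, width) patch grid, yielding one frequency row per video token.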


from ..wanvideo.modules.attention import sageattn_func


class SimpleAttnProcessor2_0:
    def __init__(self, attention_mode):
        self.attention_mode = attention_mode

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)  # [b, heads, l, c]
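
        # The rotary branch below pairs adjacent channels (x0, x1) into the complex number
        # x0 + i*x1 and multiplies by the unit-magnitude frequencies cos(t) + i*sin(t),
        # which is the 2D rotation [[cos, -sin], [sin, cos]] applied to each channel pair.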
        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)

        if self.attention_mode == 'sdpa':
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        elif self.attention_mode == 'sageattn':
            hidden_states = sageattn_func(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )
        else:
            # Previously fell through silently; fail loudly on an unknown mode instead.
            raise ValueError(f"Unsupported attention_mode: {self.attention_mode}")
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
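

# SimpleCogVideoXLayerNormZero implements adaLN-style modulation (cf. DiT's adaLN-Zero):
# one SiLU + Linear over the conditioning embedding yields per-sample (shift, scale, gate);
# activations are modulated as norm(x) * (1 + scale) + shift, and the returned gate scales
# the residual branches in the attention blocks below.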


class SimpleCogVideoXLayerNormZero(nn.Module):
    def __init__(
        self,
        conditioning_dim: int,
        embedding_dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_dim, 3 * embedding_dim, bias=bias)
        self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor):
        shift, scale, gate = self.linear(self.silu(temb)).chunk(3, dim=1)
        hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
        return hidden_states, gate[:, None, :]


class SingleAttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        ffn_dim,
        num_heads,
        time_embed_dim=512,
        qk_norm="rms_norm_across_heads",
        eps=1e-6,
        attention_mode="sdpa",
    ):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.norm1 = SimpleCogVideoXLayerNormZero(
            time_embed_dim, dim, elementwise_affine=True, eps=1e-5, bias=True
        )
        self.self_attn = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            processor=SimpleAttnProcessor2_0(attention_mode),
        )
        self.norm2 = SimpleCogVideoXLayerNormZero(
            time_embed_dim, dim, elementwise_affine=True, eps=1e-5, bias=True
        )
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim),
        )

    def forward(
        self,
        hidden_states,
        temb,
        rotary_emb,
    ):
        # norm & modulate
        norm_hidden_states, gate_msa = self.norm1(hidden_states, temb)

        # attention
        attn_hidden_states = self.self_attn(
            hidden_states=norm_hidden_states,
            rotary_emb=rotary_emb,
        )
        hidden_states = hidden_states + gate_msa * attn_hidden_states

        # norm & modulate
        norm_hidden_states, gate_ff = self.norm2(hidden_states, temb)

        # feed-forward
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = hidden_states + gate_ff * ff_output

        return hidden_states
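

# Instantiation sketch (illustrative values, not a shipped configuration):
#     block = SingleAttentionBlock(dim=1024, ffn_dim=4096, num_heads=16)
#     out = block(hidden_states, temb, rotary_emb)  # same shape as hidden_states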


class MaskCamEmbed(nn.Module):
    def __init__(self, controlnet_cfg) -> None:
        super().__init__()

        # padding bug fixed; pad order is (left, right, top, bottom, front, back) over (w, h, f)
        if controlnet_cfg.get("interp", False):
            # I2V interpolation: first and last frames given, pad both temporal ends
            self.mask_padding = [0, 0, 0, 0, 3, 3]
        else:
            # I2V: only the first frame given, pad the front of the temporal axis
            self.mask_padding = [0, 0, 0, 0, 3, 0]
        add_channels = controlnet_cfg.get("add_channels", 1)
        mid_channels = controlnet_cfg.get("mid_channels", 64)
        self.mask_proj = nn.Sequential(
            nn.Conv3d(add_channels, mid_channels, kernel_size=(4, 8, 8), stride=(4, 8, 8)),
            nn.GroupNorm(mid_channels // 8, mid_channels),
            nn.SiLU(),
        )
        self.mask_zero_proj = nn.Conv3d(
            mid_channels, controlnet_cfg["conv_out_dim"], kernel_size=(1, 2, 2), stride=(1, 2, 2)
        )

    def forward(self, add_inputs: torch.Tensor):
        # add_inputs: [b, c, f, h, w] (render mask, optionally with camera embedding channels)
        warp_add_pad = F.pad(add_inputs, self.mask_padding, mode="constant", value=0)
        add_embeds = self.mask_proj(warp_add_pad)  # [B, C, F, H, W]
        add_embeds = self.mask_zero_proj(add_embeds)
        add_embeds = rearrange(add_embeds, "b c f h w -> b (f h w) c")

        return add_embeds
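

# Note: the two convolutions downsample the inputs by (4, 16, 16) in (f, h, w), which
# appears to mirror the video VAE's (4, 8, 8) compression followed by the (1, 2, 2) patch
# embedding, so the mask tokens align one-to-one with the patchified latent tokens.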


class WanControlNet(ModelMixin):
    def __init__(self, controlnet_cfg):
        super().__init__()

        self.rope_max_seq_len = 1024
        self.patch_size = (1, 2, 2)
        self.in_channels = controlnet_cfg["in_channels"]
        self.dim = controlnet_cfg["dim"]
        self.num_heads = controlnet_cfg["num_heads"]
        self.quantized = controlnet_cfg["quantized"]
        self.base_dtype = controlnet_cfg["base_dtype"]

        if controlnet_cfg["conv_out_dim"] != controlnet_cfg["dim"]:
            self.proj_in = nn.Linear(controlnet_cfg["conv_out_dim"], controlnet_cfg["dim"])
        else:
            self.proj_in = nn.Identity()

        self.controlnet_blocks = nn.ModuleList(
            [
                SingleAttentionBlock(
                    dim=self.dim,
                    ffn_dim=controlnet_cfg["ffn_dim"],
                    num_heads=self.num_heads,
                    time_embed_dim=controlnet_cfg["time_embed_dim"],
                    qk_norm="rms_norm_across_heads",
                    attention_mode=controlnet_cfg["attention_mode"],
                )
                for _ in range(controlnet_cfg["num_layers"])
            ]
        )
        # One output projection per block, mapping to the base model's hidden size
        # (5120 matches the 14B Wan transformer width).
        self.proj_out = nn.ModuleList(
            [nn.Linear(self.dim, 5120) for _ in range(controlnet_cfg["num_layers"])]
        )

        self.gradient_checkpointing = False

        self.controlnet_rope = WanRotaryPosEmbed(
            self.dim // self.num_heads, self.patch_size, self.rope_max_seq_len
        )

        self.controlnet_patch_embedding = nn.Conv3d(
            self.in_channels,
            controlnet_cfg["conv_out_dim"],
            kernel_size=self.patch_size,
            stride=self.patch_size,
            dtype=torch.float32,
        )

        self.controlnet_mask_embedding = MaskCamEmbed(controlnet_cfg)

    def forward(self, render_latent, render_mask, camera_embedding, temb, out_device):
        controlnet_rotary_emb = self.controlnet_rope(render_latent)
        controlnet_inputs = self.controlnet_patch_embedding(render_latent.to(torch.float32))
        if not self.quantized:
            controlnet_inputs = controlnet_inputs.to(render_latent.dtype)
        else:
            controlnet_inputs = controlnet_inputs.to(self.base_dtype)

        controlnet_inputs = controlnet_inputs.flatten(2).transpose(1, 2)

        # additional inputs (mask, camera embedding)
        add_inputs = None
        if camera_embedding is not None and render_mask is not None:
            add_inputs = torch.cat([render_mask, camera_embedding], dim=1)
        elif render_mask is not None:
            add_inputs = render_mask

        if add_inputs is not None:
            add_inputs = self.controlnet_mask_embedding(add_inputs)
            controlnet_inputs = controlnet_inputs + add_inputs

        hidden_states = self.proj_in(controlnet_inputs)

        controlnet_states = []
        for i, block in enumerate(self.controlnet_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                rotary_emb=controlnet_rotary_emb,
            )
            controlnet_states.append(self.proj_out[i](hidden_states).to(out_device))

        return controlnet_states
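

# Configuration sketch (illustrative keys and values; a real checkpoint ships its own):
#     cfg = {
#         "in_channels": 16, "conv_out_dim": 1024, "dim": 1024, "ffn_dim": 4096,
#         "num_heads": 16, "time_embed_dim": 512, "num_layers": 8,
#         "attention_mode": "sdpa", "quantized": False, "base_dtype": torch.bfloat16,
#     }
#     controlnet = WanControlNet(cfg)
#     states = controlnet(render_latent, render_mask, camera_embedding, temb, out_device)
#     # `states` is a list of num_layers tensors, one [b, tokens, 5120] residual per block.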