# source https://github.com/TheDenk/wan2.1-dilated-controlnet/blob/main/wan_controlnet.py
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.transformers.transformer_wan import (
    WanTimeTextImageEmbedding,
    WanRotaryPosEmbed,
    WanTransformerBlock,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A ControlNet transformer model for video-like data, used with the Wan model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `3`):
            The number of channels in the controlnet (control video) input.
        vae_channels (`int`, defaults to `16`):
            The number of channels in the VAE latent input.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in the feed-forward network.
        num_layers (`int`, defaults to `20`):
            The number of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            The query/key normalization type.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Input dimension for image embeddings (e.g. `1280` for the I2V model).
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Maximum sequence length for the rotary position embedding.
        downscale_coef (`int`, defaults to `8`):
            Coefficient by which the controlnet input video is spatially downscaled.
        out_proj_dim (`int`, defaults to `128 * 12`):
            Output projection dimension of the final per-block linear layers.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 3,
        vae_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 20,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        downscale_coef: int = 8,
        out_proj_dim: int = 128 * 12,
    ) -> None:
        super().__init__()

        start_channels = in_channels * (downscale_coef ** 2)
        input_channels = [start_channels, start_channels // 2, start_channels // 4]

        self.control_encoder = nn.ModuleList([
            ## Spatial compression with time awareness
            nn.Sequential(
                nn.Conv3d(
                    in_channels,
                    input_channels[0],
                    kernel_size=(3, downscale_coef + 1, downscale_coef + 1),
                    stride=(1, downscale_coef, downscale_coef),
                    padding=(1, downscale_coef // 2, downscale_coef // 2)
                ),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[0]),
            ),
            ## Spatio-temporal compression with spatial awareness
            nn.Sequential(
                nn.Conv3d(input_channels[0], input_channels[1], kernel_size=3, stride=(2, 1, 1), padding=1),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[1]),
            ),
            ## Temporal compression with spatial awareness
            nn.Sequential(
                nn.Conv3d(input_channels[1], input_channels[2], kernel_size=3, stride=(2, 1, 1), padding=1),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[2]),
            ),
        ])
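        # Shape sketch: with the default downscale_coef=8 the three encoder stages take a control video of shape
        # (B, in_channels, T, H, W) down to roughly (B, in_channels * 16, ~T/4, H/8, W/8), i.e. an 8x spatial and
        # 4x temporal reduction, which is expected to match the Wan VAE latent grid so the result can be
        # concatenated channel-wise with the latent `hidden_states` in `forward`.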

        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(vae_channels + input_channels[2], inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Controlnet modules
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.blocks)):
            controlnet_block = nn.Linear(inner_dim, out_proj_dim)
            self.controlnet_blocks.append(controlnet_block)
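        # Each `controlnet_block` projects this ControlNet's hidden states (width `inner_dim`) to `out_proj_dim`,
        # which is presumably the hidden size of the base Wan transformer (128 * 12 = 1536 by default), so the
        # per-block outputs returned from `forward` can be added to the base model's block outputs by the caller.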

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        rotary_emb = self.rope(hidden_states)

        # 0. Controlnet encoder
        for control_encoder_block in self.control_encoder:
            controlnet_states = control_encoder_block(controlnet_states)

        hidden_states = torch.cat([hidden_states, controlnet_states], dim=1)

        ## 1. Patch embedding and stack
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
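        # After patchification the tensor is flattened from (B, inner_dim, T_p, H_p, W_p) into a token sequence of
        # shape (B, T_p * H_p * W_p, inner_dim), where each *_p is the latent size divided by the corresponding
        # patch_size entry.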

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ## for ComfyUI workflow
            if hidden_states.shape[1] != timestep.shape[1]:
                timestep = timestep.repeat_interleave(hidden_states.shape[1] // timestep.shape[1], dim=1)
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 4. Transformer blocks
        controlnet_hidden_states = ()
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
                controlnet_hidden_states += (controlnet_block(hidden_states),)
        else:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                controlnet_hidden_states += (controlnet_block(hidden_states),)
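        # controlnet_hidden_states is a tuple with one (B, seq_len, out_proj_dim) tensor per transformer block;
        # a consumer would typically add these, scaled by some conditioning strength, to the corresponding block
        # outputs of the base Wan transformer.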

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_hidden_states,)

        return Transformer2DModelOutput(sample=controlnet_hidden_states)
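

if __name__ == "__main__":
    # Minimal smoke-test sketch with a deliberately tiny, hypothetical configuration (real checkpoints use the
    # much larger defaults above: 40 heads, 20 layers, ffn_dim=13824). It assumes a diffusers version whose Wan
    # modules match the imports in this file, and runs one forward pass on random tensors.
    model = WanControlnet(
        num_attention_heads=2,
        attention_head_dim=8,
        in_channels=3,
        vae_channels=16,
        text_dim=4096,
        ffn_dim=32,
        num_layers=2,
        out_proj_dim=64,
    )
    batch, frames, height, width = 1, 17, 64, 64
    latents = torch.randn(batch, 16, (frames - 1) // 4 + 1, height // 8, width // 8)  # VAE latent video
    control_video = torch.randn(batch, 3, frames, height, width)  # e.g. a pose or depth control video
    text_emb = torch.randn(batch, 512, 4096)  # text encoder embeddings (4096-dim, any sequence length)
    timestep = torch.tensor([500], dtype=torch.long)

    out = model(
        hidden_states=latents,
        timestep=timestep,
        encoder_hidden_states=text_emb,
        controlnet_states=control_video,
        return_dict=False,
    )[0]
    print(len(out), out[0].shape)  # num_layers tensors of shape (batch, seq_len, out_proj_dim)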