mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-26 23:41:35 +03:00
ComfyUI-WanVideoWrapper/controlnet/wan_controlnet.py
kijai 8b037bce2e Squashed commit of the following:
commit c3eb0f49fa
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Tue Dec 9 12:55:49 2025 +0200

    move workflow

commit e129e25c26
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Tue Dec 9 11:17:17 2025 +0200

    Fix padding

commit f252f34eff
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Tue Dec 9 01:38:17 2025 +0200

    Add long video example

commit 09ceab808b
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Tue Dec 9 01:31:48 2025 +0200

    Support extension

commit 7ca221874e
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 12:28:29 2025 +0200

    Might as well not even do control pass on uncond...

commit b55caf299e
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 12:15:59 2025 +0200

    Cfg fixes

commit fd54ba23e6
Merge: 2f97b1b e867e64
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 10:39:55 2025 +0200

    Merge branch 'main' into onetoall

commit 2f97b1bd88
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 09:32:09 2025 +0200

    Add ref_mask input

commit 74cad232fd
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 03:44:42 2025 +0200

    Update nodes_model_loading.py

commit 01a038eb4a
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 03:11:08 2025 +0200

    Fix indentation

commit a95f4d6eaa
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 02:54:47 2025 +0200

    Update model.py

commit ad006985a1
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 02:54:19 2025 +0200

    Fix token replace

commit b5f0f44f17
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 02:50:52 2025 +0200

    Don't use token replace by default

commit 874174ec29
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 02:24:47 2025 +0200

    Create WanToAllAnimation_test.json

commit 9e61758556
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 02:23:15 2025 +0200

    Add token replacement

commit 41fd76dfcb
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 00:45:33 2025 +0200

    Use correct norm for reference attn

commit 705f5dcc8b
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Dec 8 00:11:17 2025 +0200

    cleanup

commit 4f095d97f8
Merge: 3e4e4db 2369cdb
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sun Dec 7 18:44:01 2025 +0200

    Merge branch 'main' into onetoall

commit 3e4e4db35d
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sun Dec 7 00:27:23 2025 +0200

    handle controlnet better

commit c5742552a9
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 6 17:24:45 2025 +0200

    cleanup

commit c06ff9c066
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 6 03:41:02 2025 +0200

    3D rope for controlnet

commit 948ea6b783
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 6 03:08:04 2025 +0200

    pose input scaling

commit 90c2eff3b2
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Dec 6 02:37:48 2025 +0200

    Cleanup

commit 9f7683422c
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Fri Dec 5 23:29:05 2025 +0200

    pose control

commit 0f217be4d8
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Fri Dec 5 20:55:10 2025 +0200

    Support reference input

# source https://github.com/TheDenk/wan2.1-dilated-controlnet/blob/main/wan_controlnet.py
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.transformers.transformer_wan import (
    WanTimeTextImageEmbedding,
    WanRotaryPosEmbed,
    WanTransformerBlock,
)

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanControlnet(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    r"""
    A ControlNet Transformer model for video-like data, used with the Wan model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each attention head.
        in_channels (`int`, defaults to `3`):
            The number of channels in the controlnet input.
        vae_channels (`int`, defaults to `16`):
            The number of channels in the VAE latent input.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in the feed-forward network.
        num_layers (`int`, defaults to `20`):
            The number of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Query/key normalization type.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Dimension of the image embeddings; `None` disables the image conditioning branch.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Maximum sequence length for the rotary position embeddings.
        downscale_coef (`int`, defaults to `8`):
            Spatial coefficient for downscaling the controlnet input video.
        out_proj_dim (`int`, defaults to `128 * 12`):
            Output projection dimension for the final linear layers.
    """
    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 3,
        vae_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 20,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        downscale_coef: int = 8,
        out_proj_dim: int = 128 * 12,
    ) -> None:
        super().__init__()

        start_channels = in_channels * (downscale_coef ** 2)
        input_channels = [start_channels, start_channels // 2, start_channels // 4]
        self.control_encoder = nn.ModuleList([
            ## Spatial compression with time awareness
            nn.Sequential(
                nn.Conv3d(
                    in_channels,
                    input_channels[0],
                    kernel_size=(3, downscale_coef + 1, downscale_coef + 1),
                    stride=(1, downscale_coef, downscale_coef),
                    padding=(1, downscale_coef // 2, downscale_coef // 2),
                ),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[0]),
            ),
            ## Spatio-Temporal compression with spatial awareness
            nn.Sequential(
                nn.Conv3d(input_channels[0], input_channels[1], kernel_size=3, stride=(2, 1, 1), padding=1),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[1]),
            ),
            ## Temporal compression with spatial awareness
            nn.Sequential(
                nn.Conv3d(input_channels[1], input_channels[2], kernel_size=3, stride=(2, 1, 1), padding=1),
                nn.GELU(approximate="tanh"),
                nn.GroupNorm(2, input_channels[2]),
            ),
        ])
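        # The three encoder stages above downscale the control video by downscale_coef spatially
        # (first conv) and by 4x temporally (two stride-2 temporal convs), which is assumed to
        # match the Wan VAE's latent grid so the encoded control states can be concatenated
        # channel-wise with the latents in forward().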
        inner_dim = num_attention_heads * attention_head_dim

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(
            vae_channels + input_channels[2], inner_dim, kernel_size=patch_size, stride=patch_size
        )

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
        )
        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Controlnet modules
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.blocks)):
            controlnet_block = nn.Linear(inner_dim, out_proj_dim)
            self.controlnet_blocks.append(controlnet_block)

        self.gradient_checkpointing = False

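    # Each controlnet_block projects a block's hidden states to out_proj_dim; the default
    # 128 * 12 = 1536 is assumed to equal the hidden size of the Wan2.1 1.3B transformer, so
    # the per-block outputs can be added directly to the main model's hidden states.
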
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        rotary_emb = self.rope(hidden_states)

        # 0. Controlnet encoder
        for control_encoder_block in self.control_encoder:
            controlnet_states = control_encoder_block(controlnet_states)
        hidden_states = torch.cat([hidden_states, controlnet_states], dim=1)

        ## 1. Patch embedding and stack
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # timestep shape: (batch_size,) or (batch_size, seq_len) for Wan 2.2 ti2v
        if timestep.ndim == 2:
            ## for ComfyUI workflow
            if hidden_states.shape[1] != timestep.shape[1]:
                timestep = timestep.repeat_interleave(hidden_states.shape[1] // timestep.shape[1], dim=1)
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
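        # timestep_proj carries the 6 adaLN modulation vectors (shift/scale/gate for self-attention
        # and for the FFN) consumed by each WanTransformerBlock; unflatten splits them out explicitly.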
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 4. Transformer blocks
        controlnet_hidden_states = ()
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
                controlnet_hidden_states += (controlnet_block(hidden_states),)
        else:
            for block, controlnet_block in zip(self.blocks, self.controlnet_blocks):
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                controlnet_hidden_states += (controlnet_block(hidden_states),)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_hidden_states,)

        return Transformer2DModelOutput(sample=controlnet_hidden_states)
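
# --- Hypothetical usage sketch (not part of the upstream file) ------------------------------
# A minimal smoke test with deliberately tiny, illustrative dimensions; real checkpoints use
# num_attention_heads=40, attention_head_dim=128 and text_dim=4096. Shapes assume the Wan VAE's
# 8x spatial / 4x temporal compression, and a diffusers version whose WanTimeTextImageEmbedding
# accepts timestep_seq_len (as required by forward() above).
if __name__ == "__main__":
    controlnet = WanControlnet(
        num_attention_heads=2,
        attention_head_dim=8,
        in_channels=3,            # RGB control video (e.g. pose frames)
        vae_channels=16,
        text_dim=32,
        ffn_dim=64,
        num_layers=2,
        out_proj_dim=16,
    )

    latents = torch.randn(1, 16, 3, 8, 8)         # (B, C_vae, T_lat, H_lat, W_lat)
    control_video = torch.randn(1, 3, 9, 64, 64)  # pixel-space control: 9 frames at 64x64
    text_embeds = torch.randn(1, 16, 32)          # (B, num_text_tokens, text_dim)
    timestep = torch.tensor([500], dtype=torch.long)

    controlnet_residuals = controlnet(
        hidden_states=latents,
        timestep=timestep,
        encoder_hidden_states=text_embeds,
        controlnet_states=control_video,
        return_dict=False,
    )[0]

    # One tensor per transformer block, each projected to out_proj_dim; these are intended to be
    # added into the corresponding blocks of the main Wan model, scaled by a controlnet strength.
    print(len(controlnet_residuals), controlnet_residuals[0].shape)  # 2 torch.Size([1, 48, 16])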