diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index eddf196718..2c8b76912b 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -325,7 +325,7 @@ class WanTransformerBlock(nn.Module): c_scale_msa = c_scale_msa.squeeze(2) c_gate_msa = c_gate_msa.squeeze(2) else: - # temb: batch_size, 6, inner_dim + # temb: batch_size, 6, inner_dim (wan2.1) shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( self.scale_shift_table + temb.float() ).chunk(6, dim=1)