diff --git a/check_mako.py b/check_mako.py
index c195107e2c..7e73c7295f 100644
--- a/check_mako.py
+++ b/check_mako.py
@@ -119,13 +119,15 @@ def get_prompts():
     return prompt, negative_prompt
 
 
+# Fixing batch size of 2 and `max_sequence_length` of 256 because of the kernels.
 def run_inference(pipeline, prompt, negative_prompt, num_inference_steps=50):
     output = pipeline(
-        prompt=prompt,
+        prompt=[prompt] * 2,
         negative_prompt=negative_prompt,
         num_frames=81,
         guidance_scale=5.0,
         num_inference_steps=num_inference_steps,
+        max_sequence_length=256,
         generator=torch.manual_seed(0)
     ).frames[0]
     return output
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index d7e6512af9..70619473aa 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -462,8 +462,10 @@ class WanTransformerBlock(nn.Module):
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
     ) -> torch.Tensor:
-        # Instead of performing the output projections on the attention outputs in the attention block
+        # Notes: Instead of performing the output projections on the attention outputs in the attention block
         # we perform them here to take advantage of fusion.
+        if not hidden_states.is_contiguous():
+            hidden_states = hidden_states.contiguous()
         B, S, D = hidden_states.shape
         if temb.ndim == 4:
             # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
@@ -500,7 +502,6 @@ class WanTransformerBlock(nn.Module):
         attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         # hidden_states = hidden_states + attn_output
         # Fused Cross-Attn Output Proj + Residual (no gate)
-        print(f"{hidden_states.shape=}, {attn_output.shape=}, {self.attn2.to_out[0].weight.shape=}, {self.attn2.to_out[0].bias.shape=}")
         hidden_states = fused_matmul_residual(
             attn_output, self.attn2.to_out[0].weight, self.attn2.to_out[0].bias, hidden_states
         )
@@ -510,7 +511,6 @@ class WanTransformerBlock(nn.Module):
         #     hidden_states
         # )
         norm_hidden_states = triton_adaptive_norm(hidden_states, c_scale_msa, c_shift_msa, self.norm3.eps)
-        print(f"{norm_hidden_states.shape=}")
 
         # ff_output = self.ffn(norm_hidden_states)
         # hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
diff --git a/src/diffusers/models/transformers/wan_mako_attention_processor.py b/src/diffusers/models/transformers/wan_mako_attention_processor.py
index 583d259951..9b244b81dc 100644
--- a/src/diffusers/models/transformers/wan_mako_attention_processor.py
+++ b/src/diffusers/models/transformers/wan_mako_attention_processor.py
@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 from typing import Optional, Tuple
 from .transformer_wan import WanAttention
-from .wan_mako_kernels import triton_matmul, triton_rms_norm2, fused_matmul_residual
+from .wan_mako_kernels import triton_matmul, triton_rms_norm2
 
 
 # TODO: incorporate I2V support
@@ -78,4 +78,4 @@ class WanMakoAttnProcessor:
 
         # (B, H, S, head_dim) -> (B, S, D)
         attn_out = attn_out.transpose(1, 2).reshape(B, S, D)
-        return attn_out
+        return attn_out.contiguous() if not attn_out.is_contiguous() else attn_out
diff --git a/src/diffusers/models/transformers/wan_mako_kernels.py b/src/diffusers/models/transformers/wan_mako_kernels.py
index 334c59cf0e..7a465e9388 100644
--- a/src/diffusers/models/transformers/wan_mako_kernels.py
+++ b/src/diffusers/models/transformers/wan_mako_kernels.py
@@ -587,7 +587,6 @@ class ModelNew(nn.Module):
         attn_out2 = attn_out2.transpose(1, 2).contiguous().view(B, S, D)
 
         # Fused Cross-Attn Output Proj + Residual (no gate)
-        print(f"{hidden_states.shape=}, {attn_out2.shape=}, {block.attn2.to_out[0].weight.shape=}, {block.attn2.to_out[0].bias.shape=}")
         hidden_states = fused_matmul_residual(
             attn_out2, block.attn2.to_out[0].weight, block.attn2.to_out[0].bias, hidden_states
         )
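
Reviewer note: the patch relies on `fused_matmul_residual(x, weight, bias, residual)` folding the cross-attention output projection and the residual add into one kernel, and it forces contiguous inputs because Triton kernels typically index with raw pointers and strides that assume a contiguous layout. A minimal PyTorch sketch of the semantics assumed here (illustrative reference only, not the Triton implementation in wan_mako_kernels.py; the function name and shapes below are assumptions) is:

    import torch

    def fused_matmul_residual_reference(
        x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor
    ) -> torch.Tensor:
        # Assumed shapes: x (B, S, D_in); weight/bias follow the nn.Linear convention,
        # i.e. (D_out, D_in) and (D_out,); residual (B, S, D_out).
        # Assumed semantics: residual + x @ weight.T + bias, computed in a single pass.
        out = torch.addmm(bias, x.reshape(-1, x.shape[-1]), weight.t())
        return residual + out.reshape(residual.shape)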