mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Sayak Paul
2025-12-11 04:47:57 +00:00
parent 2f947c423f
commit 2945a4fff7
4 changed files with 8 additions and 7 deletions

View File

@@ -119,13 +119,15 @@ def get_prompts():
     return prompt, negative_prompt
+# Fixing batch size of 2 and `max_sequence_length` of 256 because of the kernels.
 def run_inference(pipeline, prompt, negative_prompt, num_inference_steps=50):
     output = pipeline(
-        prompt=prompt,
+        prompt=[prompt] * 2,
         negative_prompt=negative_prompt,
         num_frames=81,
         guidance_scale=5.0,
         num_inference_steps=num_inference_steps,
+        max_sequence_length=256,
         generator=torch.manual_seed(0)
     ).frames[0]
     return output
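
Note: the added comment and arguments pin the shapes that reach the custom kernels. A rough sketch of the assumption, where the hidden size is an illustrative placeholder and not taken from the diff:

import torch

# With 2 prompts and max_sequence_length=256, the text-encoder output that every
# transformer block sees has a static (batch, seq_len, dim) shape, which is what
# shape-specialized Triton kernels expect.
batch_size, seq_len, dim = 2, 256, 4096  # dim is assumed, not from the diff
encoder_hidden_states = torch.randn(batch_size, seq_len, dim)
assert encoder_hidden_states.shape == (2, 256, dim)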

View File

@@ -462,8 +462,10 @@ class WanTransformerBlock(nn.Module):
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
     ) -> torch.Tensor:
-        # Instead of performing the output projections on the attention outputs in the attention block
+        # Notes: Instead of performing the output projections on the attention outputs in the attention block
         # we perform them here to take advantage of fusion.
+        if not hidden_states.is_contiguous():
+            hidden_states = hidden_states.contiguous()
         B, S, D = hidden_states.shape
         if temb.ndim == 4:
             # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
@@ -500,7 +502,6 @@ class WanTransformerBlock(nn.Module):
         attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         # hidden_states = hidden_states + attn_output
         # Fused Cross-Attn Output Proj + Residual (no gate)
-        print(f"{hidden_states.shape=}, {attn_output.shape=}, {self.attn2.to_out[0].weight.shape=}, {self.attn2.to_out[0].bias.shape=}")
         hidden_states = fused_matmul_residual(
             attn_output, self.attn2.to_out[0].weight, self.attn2.to_out[0].bias, hidden_states
         )
@@ -510,7 +511,6 @@ class WanTransformerBlock(nn.Module):
         # hidden_states
         # )
         norm_hidden_states = triton_adaptive_norm(hidden_states, c_scale_msa, c_shift_msa, self.norm3.eps)
-        print(f"{norm_hidden_states.shape=}")
         # ff_output = self.ffn(norm_hidden_states)
         # hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
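
Note: `fused_matmul_residual` and `triton_adaptive_norm` are custom Triton kernels; their names suggest the semantics, but the sketch below is an assumption, not the kernel code. A plain-PyTorch reference of what they are taken to compute:

import torch
import torch.nn.functional as F

def fused_matmul_residual_reference(x, weight, bias, residual):
    # Assumed semantics: the output projection (nn.Linear with `weight`/`bias`)
    # followed by the residual add, which the real kernel does in a single pass.
    return F.linear(x, weight, bias) + residual

def triton_adaptive_norm_reference(x, scale, shift, eps):
    # Assumed semantics: LayerNorm without affine parameters, then the usual
    # adaptive modulation norm(x) * (1 + scale) + shift.
    normed = F.layer_norm(x.float(), (x.shape[-1],), eps=eps)
    return (normed * (1 + scale) + shift).type_as(x)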

View File

@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 from typing import Optional, Tuple
 from .transformer_wan import WanAttention
-from .wan_mako_kernels import triton_matmul, triton_rms_norm2, fused_matmul_residual
+from .wan_mako_kernels import triton_matmul, triton_rms_norm2
 # TODO: incorporate I2V support
@@ -78,4 +78,4 @@ class WanMakoAttnProcessor:
         # (B, H, S, head_dim) -> (B, S, D)
         attn_out = attn_out.transpose(1, 2).reshape(B, S, D)
-        return attn_out
+        return attn_out.contiguous() if not attn_out.is_contiguous() else attn_out
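
Note: the contiguity guard matters because a transpose hands back a strided view, and the downstream Triton kernels assume contiguous memory. A small standalone illustration with made-up shapes:

import torch

x = torch.randn(2, 8, 256, 64)       # (B, H, S, head_dim), illustrative sizes
y = x.transpose(1, 2)                # permuted strides: a non-contiguous view
assert not y.is_contiguous()
y = y.contiguous() if not y.is_contiguous() else y   # same guard as in the diff
assert y.is_contiguous()

In the diff the guard runs after `reshape`, which often already copies into contiguous memory, so the branch mostly avoids a redundant copy.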

View File

@@ -587,7 +587,6 @@ class ModelNew(nn.Module):
         attn_out2 = attn_out2.transpose(1, 2).contiguous().view(B, S, D)
         # Fused Cross-Attn Output Proj + Residual (no gate)
-        print(f"{hidden_states.shape=}, {attn_out2.shape=}, {block.attn2.to_out[0].weight.shape=}, {block.attn2.to_out[0].bias.shape=}")
         hidden_states = fused_matmul_residual(
             attn_out2, block.attn2.to_out[0].weight, block.attn2.to_out[0].bias, hidden_states
         )
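
Note: with the debug print removed, one way to convince yourself the fused call still matches the eager path is an equivalence check like the sketch below; sizes are illustrative, and the fused call is commented out because the kernel is not importable here.

import torch
import torch.nn.functional as F

B, S, D = 2, 256, 1536                          # illustrative sizes
attn_out2 = torch.randn(B, S, D)
weight, bias = torch.randn(D, D), torch.randn(D)
hidden_states = torch.randn(B, S, D)
expected = hidden_states + F.linear(attn_out2, weight, bias)
# torch.testing.assert_close(fused_matmul_residual(attn_out2, weight, bias, hidden_states), expected)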