mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
up
@@ -119,13 +119,15 @@ def get_prompts():
     return prompt, negative_prompt
 
 
+# Fixing batch size of 2 and `max_sequence_length` of 256 because of the kernels.
 def run_inference(pipeline, prompt, negative_prompt, num_inference_steps=50):
     output = pipeline(
-        prompt=prompt,
+        prompt=[prompt] * 2,
         negative_prompt=negative_prompt,
         num_frames=81,
         guidance_scale=5.0,
         num_inference_steps=num_inference_steps,
+        max_sequence_length=256,
         generator=torch.manual_seed(0)
     ).frames[0]
     return output
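Side note (not part of the commit): a minimal sketch of how the updated run_inference could be driven. The pipeline class, checkpoint id, dtype, and device below are assumptions, not something this diff specifies.

# Hypothetical harness; checkpoint id, dtype, and device are assumptions.
import torch
from diffusers import WanPipeline

pipeline = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

prompt, negative_prompt = get_prompts()
# The kernels this commit targets assume a batch of 2 prompts and
# max_sequence_length=256, which run_inference now pins.
frames = run_inference(pipeline, prompt, negative_prompt, num_inference_steps=50)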
@@ -462,8 +462,10 @@ class WanTransformerBlock(nn.Module):
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
     ) -> torch.Tensor:
-        # Instead of performing the output projections on the attention outputs in the attention block
+        # Notes: Instead of performing the output projections on the attention outputs in the attention block
         # we perform them here to take advantage of fusion.
+        if not hidden_states.is_contiguous():
+            hidden_states = hidden_states.contiguous()
         B, S, D = hidden_states.shape
         if temb.ndim == 4:
             # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
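Side note (not part of the commit): the added guard reflects a common precondition for Triton-style kernels, which compute element offsets from a base pointer assuming a dense row-major layout, so a strided or transposed view has to be materialized first. A minimal sketch of that guard as a reusable helper; the helper name is hypothetical.

import torch

def ensure_contiguous(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not in the diff: materialize strided/transposed
    # views so pointer-arithmetic kernels read the elements they expect.
    return t if t.is_contiguous() else t.contiguous()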
@@ -500,7 +502,6 @@ class WanTransformerBlock(nn.Module):
         attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
         # hidden_states = hidden_states + attn_output
         # Fused Cross-Attn Output Proj + Residual (no gate)
-        print(f"{hidden_states.shape=}, {attn_output.shape=}, {self.attn2.to_out[0].weight.shape=}, {self.attn2.to_out[0].bias.shape=}")
         hidden_states = fused_matmul_residual(
             attn_output, self.attn2.to_out[0].weight, self.attn2.to_out[0].bias, hidden_states
         )
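Side note (not part of the commit): fused_matmul_residual itself lives in wan_mako_kernels and is not shown in this diff. Judging from the call site, its assumed semantics are the nn.Linear output projection fused with the residual add; a plain-PyTorch reference under that assumption.

import torch
import torch.nn.functional as F

def fused_matmul_residual_ref(x, weight, bias, residual):
    # Assumed semantics of fused_matmul_residual(x, W, b, residual):
    # residual + x @ W.T + b, with W laid out like nn.Linear.weight
    # (out_features, in_features). The real kernel presumably does this
    # in one pass instead of a matmul followed by a separate add.
    return residual + F.linear(x, weight, bias)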
@@ -510,7 +511,6 @@ class WanTransformerBlock(nn.Module):
         #     hidden_states
         # )
         norm_hidden_states = triton_adaptive_norm(hidden_states, c_scale_msa, c_shift_msa, self.norm3.eps)
-        print(f"{norm_hidden_states.shape=}")
         # ff_output = self.ffn(norm_hidden_states)
         # hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
 
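Side note (not part of the commit): triton_adaptive_norm is also defined elsewhere. Based on the call site and the upstream WanTransformerBlock feed-forward path it appears to replace (norm3 followed by a (1 + scale) / shift modulation), its assumed semantics are an adaLN-style modulated LayerNorm; a sketch under that assumption.

import torch
import torch.nn.functional as F

def adaptive_norm_ref(x, scale, shift, eps):
    # Assumed semantics of triton_adaptive_norm(x, scale, shift, eps):
    # LayerNorm without learned affine parameters, modulated per token by
    # conditioning-derived scale/shift, i.e. norm(x) * (1 + scale) + shift,
    # computed in fp32 and cast back to the input dtype.
    normed = F.layer_norm(x.float(), (x.shape[-1],), eps=eps)
    return (normed * (1 + scale) + shift).type_as(x)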
@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 from typing import Optional, Tuple
 from .transformer_wan import WanAttention
-from .wan_mako_kernels import triton_matmul, triton_rms_norm2, fused_matmul_residual
+from .wan_mako_kernels import triton_matmul, triton_rms_norm2
 
 
 # TODO: incorporate I2V support
@@ -78,4 +78,4 @@ class WanMakoAttnProcessor:
         # (B, H, S, head_dim) -> (B, S, D)
         attn_out = attn_out.transpose(1, 2).reshape(B, S, D)
 
-        return attn_out
+        return attn_out.contiguous() if not attn_out.is_contiguous() else attn_out
@@ -587,7 +587,6 @@ class ModelNew(nn.Module):
             attn_out2 = attn_out2.transpose(1, 2).contiguous().view(B, S, D)
 
             # Fused Cross-Attn Output Proj + Residual (no gate)
-            print(f"{hidden_states.shape=}, {attn_out2.shape=}, {block.attn2.to_out[0].weight.shape=}, {block.attn2.to_out[0].bias.shape=}")
             hidden_states = fused_matmul_residual(
                 attn_out2, block.attn2.to_out[0].weight, block.attn2.to_out[0].bias, hidden_states
             )
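Side note (not part of the commit): with the debug prints gone, a small numerical check against the unfused path is a handier substitute for eyeballing shapes. A hypothetical sanity check, assuming the fused_matmul_residual semantics sketched above.

import torch
import torch.nn.functional as F

def check_fused_matmul_residual(fused_fn, to_out, x, residual, atol=1e-2, rtol=1e-2):
    # Hypothetical check: compare the fused kernel against the plain
    # output-projection + residual path (loose tolerances for fp16/bf16).
    expected = residual + F.linear(x, to_out.weight, to_out.bias)
    actual = fused_fn(x, to_out.weight, to_out.bias, residual)
    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)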