From b785ddb654e4be3ae0066e231734754bdb2a191c Mon Sep 17 00:00:00 2001
From: Junyu Chen <70215701+chenjy2003@users.noreply.github.com>
Date: Thu, 16 Jan 2025 19:19:02 +0800
Subject: [PATCH] [DC-AE, SANA] fix SanaMultiscaleLinearAttention
 apply_quadratic_attention bf16 (#10595)

* autoencoder_dc tiling

* add tiling and slicing support in SANA pipelines

* create variables for padding length because the line becomes too long

* add tiling and slicing support in pag SANA pipelines

* revert changes to tile size

* make style

* add vae tiling test

* fix SanaMultiscaleLinearAttention apply_quadratic_attention bf16

---------

Co-authored-by: Aryan
---
 src/diffusers/models/attention_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 4d7ae6bef2..967ebf8649 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -899,7 +899,7 @@ class SanaMultiscaleLinearAttention(nn.Module):
         scores = torch.matmul(key.transpose(-1, -2), query)
         scores = scores.to(dtype=torch.float32)
         scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-        hidden_states = torch.matmul(value, scores)
+        hidden_states = torch.matmul(value, scores.to(value.dtype))
         return hidden_states

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
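
A minimal sketch of the failure this one-line change addresses: in apply_quadratic_attention the scores are upcast to float32 for the normalization, while value stays in the model dtype (e.g. bfloat16), and torch.matmul does not promote mixed dtypes. The tensor shapes below are illustrative only, not taken from the SANA model.

```python
import torch

# value stays in the model dtype (bf16); scores were upcast to fp32 for normalization.
value = torch.randn(1, 8, 32, 16, dtype=torch.bfloat16)  # illustrative shape
scores = torch.randn(1, 8, 16, 16, dtype=torch.float32)  # illustrative shape

try:
    torch.matmul(value, scores)  # mixed bf16/fp32 matmul raises a dtype error
except RuntimeError as e:
    print(f"before the fix: {e}")

# After the fix: cast scores back to the value dtype before the matmul.
hidden_states = torch.matmul(value, scores.to(value.dtype))
print(hidden_states.dtype)  # torch.bfloat16
```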