mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[DC-AE, SANA] fix SanaMultiscaleLinearAttention apply_quadratic_attention bf16 (#10595)
* autoencoder_dc tiling * add tiling and slicing support in SANA pipelines * create variables for padding length because the line becomes too long * add tiling and slicing support in pag SANA pipelines * revert changes to tile size * make style * add vae tiling test * fix SanaMultiscaleLinearAttention apply_quadratic_attention bf16 --------- Co-authored-by: Aryan <aryan@huggingface.co>
This commit is contained in:
@@ -899,7 +899,7 @@ class SanaMultiscaleLinearAttention(nn.Module):
|
||||
scores = torch.matmul(key.transpose(-1, -2), query)
|
||||
scores = scores.to(dtype=torch.float32)
|
||||
scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
|
||||
hidden_states = torch.matmul(value, scores)
|
||||
hidden_states = torch.matmul(value, scores.to(value.dtype))
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user