diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 4d7ae6bef2..967ebf8649 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -899,7 +899,7 @@ class SanaMultiscaleLinearAttention(nn.Module):
         scores = torch.matmul(key.transpose(-1, -2), query)
         scores = scores.to(dtype=torch.float32)
         scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-        hidden_states = torch.matmul(value, scores)
+        hidden_states = torch.matmul(value, scores.to(value.dtype))
         return hidden_states

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: