Mirror of https://github.com/huggingface/diffusers.git

mps cross-attention hack: don't crash on fp16 (#2258)

* mps cross-attention hack: don't crash on fp16

* Make conversion explicit.
Pedro Cuenca authored on 2023-02-07 19:51:33 +01:00, committed by GitHub
parent 111228cb39
commit e619db24be


@@ -251,7 +251,7 @@ class CrossAttention(nn.Module):
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
                 padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
-                padding = torch.zeros(padding_shape, device=attention_mask.device)
+                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                 attention_mask = torch.concat([attention_mask, padding], dim=2)
             else:
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
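
A minimal, self-contained sketch of what the one-line change does (illustrative shapes and names, not the diffusers code itself): torch.zeros defaults to float32, so before this commit the manually built padding could have a different dtype than an fp16 attention mask; passing dtype=attention_mask.dtype keeps the concatenated mask in the mask's own precision, which per the commit title avoided the crash on MPS in fp16.

import torch

# Stand-in for an attention mask produced by a pipeline running in half precision.
attention_mask = torch.zeros(2, 8, 64, dtype=torch.float16)
target_length = 13  # illustrative amount of extra padding

padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)

# Before the fix: padding defaulted to float32, so the concat below mixed fp16 and fp32.
# padding = torch.zeros(padding_shape, device=attention_mask.device)

# After the fix: padding matches the mask's dtype explicitly, so the result stays fp16.
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.concat([attention_mask, padding], dim=2)

print(attention_mask.shape)  # torch.Size([2, 8, 77])
print(attention_mask.dtype)  # torch.float16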