From e619db24becce564a3646cef35ae83b06955867f Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 7 Feb 2023 19:51:33 +0100
Subject: [PATCH] mps cross-attention hack: don't crash on fp16 (#2258)

* mps cross-attention hack: don't crash on fp16

* Make conversion explicit.
---
 src/diffusers/models/cross_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/models/cross_attention.py b/src/diffusers/models/cross_attention.py
index 0602f2ee04..2ea2e7be58 100644
--- a/src/diffusers/models/cross_attention.py
+++ b/src/diffusers/models/cross_attention.py
@@ -251,7 +251,7 @@ class CrossAttention(nn.Module):
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
                 padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
-                padding = torch.zeros(padding_shape, device=attention_mask.device)
+                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                 attention_mask = torch.concat([attention_mask, padding], dim=2)
             else:
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
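
For reference, a minimal standalone sketch of the padding construction this patch changes (shapes and values here are hypothetical, not taken from the diffusers code). It shows why matching the padding dtype to the attention mask matters: the patched line keeps torch.concat from mixing an fp16 mask with fp32 zeros when the pipeline runs in half precision, e.g. on MPS.

    import torch

    # Hypothetical fp16 attention mask: (batch, heads, current key length)
    attention_mask = torch.zeros(2, 1, 64, dtype=torch.float16)
    target_length = 13  # hypothetical number of extra key positions to pad

    # Build the padding with the same dtype and device as the mask, as in the patch.
    padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
    padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
    attention_mask = torch.concat([attention_mask, padding], dim=2)
    print(attention_mask.shape, attention_mask.dtype)  # torch.Size([2, 1, 77]) torch.float16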