Mirror of https://github.com/huggingface/diffusers.git
mps cross-attention hack: don't crash on fp16 (#2258)
* mps cross-attention hack: don't crash on fp16
* Make conversion explicit.
@@ -251,7 +251,7 @@ class CrossAttention(nn.Module):
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
                 padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
-                padding = torch.zeros(padding_shape, device=attention_mask.device)
+                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                 attention_mask = torch.concat([attention_mask, padding], dim=2)
             else:
                 attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
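For context, a minimal standalone sketch of the patched branch. The wrapper function and the surrounding length check are assumptions for illustration; the variable names and the two branches follow the diff above.

import torch
import torch.nn.functional as F

def prepare_attention_mask(attention_mask: torch.Tensor, target_length: int) -> torch.Tensor:
    # Hypothetical standalone version of the padding logic in CrossAttention.
    if attention_mask.shape[-1] != target_length:
        if attention_mask.device.type == "mps":
            # HACK: MPS: Does not support padding by greater than dimension of input tensor.
            # Instead, we can manually construct the padding tensor.
            padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
            # The fix: build the padding in the mask's own dtype. Previously torch.zeros
            # defaulted to float32, and combining it with a float16 mask crashed on MPS.
            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.concat([attention_mask, padding], dim=2)
        else:
            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
    return attention_mask

With a float16 mask on MPS, the padding tensor now inherits torch.float16 from the mask, so the explicit dtype conversion keeps torch.concat from hitting the fp16 crash the commit title refers to.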