1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix attention mask pad check (#3531)

This commit is contained in:
Will Berman
2023-05-23 13:11:53 -07:00
committed by GitHub
parent bde2cb5d9b
commit c13dbd5c3a

View File

@@ -381,12 +381,7 @@ class Attention(nn.Module):
return attention_mask
current_length: int = attention_mask.shape[-1]
if current_length > target_length:
# we *could* trim the mask with:
# attention_mask = attention_mask[:,:target_length]
# but this is weird enough that it's more likely to be a mistake than a shortcut
raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
elif current_length < target_length:
if current_length != target_length:
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.