mirror of https://github.com/huggingface/diffusers.git
fix attention mask pad check (#3531)
@@ -381,12 +381,7 @@ class Attention(nn.Module):
             return attention_mask
 
         current_length: int = attention_mask.shape[-1]
-        if current_length > target_length:
-            # we *could* trim the mask with:
-            #   attention_mask = attention_mask[:,:target_length]
-            # but this is weird enough that it's more likely to be a mistake than a shortcut
-            raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
-        elif current_length < target_length:
+        if current_length != target_length:
             if attention_mask.device.type == "mps":
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
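For context, here is a minimal, runnable sketch of the padding path this hunk guards. The standalone helper name `pad_attention_mask` and the pad amount `target_length - current_length` are assumptions for illustration; in diffusers this logic lives inside `Attention.prepare_attention_mask`, not a free function.

import torch
import torch.nn.functional as F

def pad_attention_mask(attention_mask: torch.Tensor, target_length: int) -> torch.Tensor:
    # Hypothetical helper (not the diffusers API): pad the last dim of
    # `attention_mask` with zeros up to `target_length`.
    current_length = attention_mask.shape[-1]
    if current_length != target_length:
        # Assumes current_length <= target_length; after this commit the
        # longer-than-target case is no longer rejected with a ValueError.
        if attention_mask.device.type == "mps":
            # MPS cannot pad by more than the input tensor's dimension,
            # so construct the padding tensor manually and concatenate.
            padding_shape = attention_mask.shape[:-1] + (target_length - current_length,)
            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat([attention_mask, padding], dim=-1)
        else:
            # F.pad with (left, right) pads the last dimension on the right.
            attention_mask = F.pad(attention_mask, (0, target_length - current_length), value=0.0)
    return attention_mask

mask = torch.zeros(2, 1, 77)            # e.g. (batch, heads, sequence)
padded = pad_attention_mask(mask, 128)
assert padded.shape == (2, 1, 128)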