diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index c71e2a8336..7d8b962905 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -520,12 +520,12 @@ class CrossAttention(nn.Module):
         query = (
             self.reshape_heads_to_batch_dim(query)
             .permute(0, 2, 1, 3)
-            .reshape(batch_size * head_size, seq_len, dim // head_size)
+            .reshape(batch_size * self.heads, seq_len, dim // self.heads)
         )
         value = (
             self.reshape_heads_to_batch_dim(value)
             .permute(0, 2, 1, 3)
-            .reshape(batch_size * head_size, seq_len, dim // head_size)
+            .reshape(batch_size * self.heads, seq_len, dim // self.heads)
         )
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
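
Context for reviewers: the hunk replaces `head_size`, which is apparently not defined in this scope, with the module attribute `self.heads`. Below is a minimal standalone sketch of the heads-to-batch reshape the fixed lines perform; the shape values are hypothetical and chosen only for illustration, not taken from the library.

import torch

# Hypothetical shapes for illustration; variable names mirror the hunk above.
batch_size, seq_len, heads, dim = 2, 77, 8, 320

query = torch.randn(batch_size, seq_len, dim)

# Split the channel dimension into heads, move the head axis next to the
# batch axis, then fold the two together so each head becomes its own
# batch element: (B, S, D) -> (B, S, H, D/H) -> (B, H, S, D/H) -> (B*H, S, D/H).
# .reshape (unlike .view) copies if needed, so it is safe after .permute.
query = (
    query.reshape(batch_size, seq_len, heads, dim // heads)
    .permute(0, 2, 1, 3)
    .reshape(batch_size * heads, seq_len, dim // heads)
)

assert query.shape == (batch_size * heads, seq_len, dim // heads)  # (16, 77, 40)

Folding the heads into the batch dimension lets the subsequent batched matmul and softmax treat every head as an independent batch element, which is why the per-head feature size is `dim // self.heads`.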