diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 6fdb7b286f..d0a7ab3c69 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -512,6 +512,7 @@ class CrossAttention(nn.Module):
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
+        context_sequence_length = context.shape[1]
         key = self.to_k(context)
         value = self.to_v(context)
 
@@ -525,19 +526,19 @@ class CrossAttention(nn.Module):
         value = (
             self.reshape_heads_to_batch_dim(value)
             .permute(0, 2, 1, 3)
-            .reshape(batch_size * self.heads, sequence_length, dim // self.heads)
+            .reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
         )
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
-            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, sequence_length, dim // self.heads)
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
-            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, sequence_length)
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, context_sequence_length)
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
             else:
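
Why the change is needed: in cross-attention the query is projected from hidden_states (flattened spatial tokens), while key and value are projected from the encoder context (e.g. 77 text tokens), so reshaping key/value with the query's sequence_length is only valid in the self-attention case where context is hidden_states. The snippet below is a standalone sketch with plain torch tensors and made-up shape numbers, not the diffusers CrossAttention API, to illustrate the shape mismatch the patch fixes.

# Minimal sketch (assumed shapes, not the diffusers API) of why key/value must be
# split into heads using the context length rather than the query's sequence length.
import torch

batch_size, heads, dim = 2, 8, 320
sequence_length = 4096          # query side: flattened spatial tokens of hidden_states
context_sequence_length = 77    # key/value side: text-encoder tokens

key = torch.randn(batch_size, context_sequence_length, dim)

# Splitting heads with the query's sequence_length fails because the element
# counts do not match (2 * 77 * 320 != 2 * 4096 * 320):
try:
    key.reshape(batch_size, sequence_length, heads, dim // heads)
except RuntimeError as err:
    print("invalid reshape:", err)

# Using the context length, as in the patch above, the head split is well defined:
key_heads = (
    key.reshape(batch_size, context_sequence_length, heads, dim // heads)
    .permute(0, 2, 1, 3)
    .reshape(batch_size * heads, context_sequence_length, dim // heads)
)
print(key_heads.shape)  # torch.Size([16, 77, 40])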