From d5af4fd1531666287717f0e07a25acfac42d7dbf Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 22 Nov 2022 12:21:59 +0100
Subject: [PATCH] Woops

---
 src/diffusers/models/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 200d108550..6fdb7b286f 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -532,12 +532,12 @@ class CrossAttention(nn.Module):
 
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
-            key = key.permute(0, 2, 1, 3).reshape(batch_size * self.heads, sequence_length, dim // self.heads)
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, sequence_length, dim // self.heads)
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
-            key = key.permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, sequence_length)
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, sequence_length)
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
             else:
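
Note (reviewer commentary, not part of the patch): the two changed lines shape `key` for different attention backends. The xformers branch targets a `(batch * heads, sequence_length, dim // heads)` layout, while the fallback branch targets the key pre-transposed to `(batch * heads, dim // heads, sequence_length)` so the score matmul needs no extra transpose. The snippet below is a minimal, standalone sketch of that shape bookkeeping in plain PyTorch; the sizes are illustrative, it assumes the key enters as a `(batch, seq_len, heads, head_dim)` split, and it does not reproduce this fork's `reshape_heads_to_batch_dim` helper, whose exact return shape cannot be confirmed from the diff alone.

    import torch

    # Illustrative sizes (assumptions, not taken from the patch)
    batch_size, sequence_length, dim, heads = 2, 77, 320, 8
    head_dim = dim // heads

    # key as produced by a to_k projection: (batch, seq_len, dim)
    key = torch.randn(batch_size, sequence_length, dim)

    # Split the feature dim into heads: (batch, seq_len, heads, head_dim)
    key_4d = key.reshape(batch_size, sequence_length, heads, head_dim)

    # xformers-style layout: (batch * heads, seq_len, head_dim)
    key_xformers = key_4d.permute(0, 2, 1, 3).reshape(batch_size * heads, sequence_length, head_dim)

    # matmul-style layout with the key pre-transposed: (batch * heads, head_dim, seq_len)
    key_transposed = key_4d.permute(0, 2, 3, 1).reshape(batch_size * heads, head_dim, sequence_length)

    print(key_xformers.shape)    # torch.Size([16, 77, 40])
    print(key_transposed.shape)  # torch.Size([16, 40, 77])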