commit d5af4fd153 (parent 1f135ac219)
Author: thomasw21
Date:   2022-11-22 12:21:59 +01:00


@@ -532,12 +532,12 @@ class CrossAttention(nn.Module):
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
-            key = key.permute(0, 2, 1, 3).reshape(batch_size * self.heads, sequence_length, dim // self.heads)
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, sequence_length, dim // self.heads)
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
-            key = key.permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, sequence_length)
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, sequence_length)
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
             else:
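
For context on the change: the added calls only type-check if reshape_heads_to_batch_dim returns a 4-D (batch, sequence_length, heads, dim // heads) view, since the diff immediately permutes four dimensions and then folds the heads into the batch. A minimal sketch under that assumption (the helper body below is illustrative, not copied from the repository):

import torch

def reshape_heads_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # Assumed behavior: split the channel dimension into (heads, dim // heads)
    # and leave the fold into the batch dimension to the caller, matching the
    # permute(...).reshape(...) chains in the diff above.
    batch_size, sequence_length, dim = tensor.shape
    return tensor.reshape(batch_size, sequence_length, heads, dim // heads)

# Mirrors the xformers branch: (b, s, h, d/h) -> (b*h, s, d/h).
batch_size, sequence_length, dim, heads = 2, 77, 320, 8
key = torch.randn(batch_size, sequence_length, dim)
key = reshape_heads_to_batch_dim(key, heads).permute(0, 2, 1, 3).reshape(
    batch_size * heads, sequence_length, dim // heads
)
print(key.shape)  # torch.Size([16, 77, 40])

The else branch applies the same fold with permute(0, 2, 3, 1) instead, producing a pre-transposed key of shape (batch * heads, dim // heads, sequence_length), presumably so the score computation in _attention can multiply it against query directly.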