diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index ad64e30d1f..c71e2a8336 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -497,14 +497,13 @@ class CrossAttention(nn.Module):
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.view(batch_size, seq_len, head_size, dim // head_size)
         return tensor

     def reshape_batch_dim_to_heads(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.view(batch_size // head_size, head_size, seq_len, dim)
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor

@@ -518,18 +517,28 @@ class CrossAttention(nn.Module):

         dim = query.shape[-1]

-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        query = (
+            self.reshape_heads_to_batch_dim(query)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size * head_size, seq_len, dim // head_size)
+        )
+        key = self.reshape_heads_to_batch_dim(key)
+        value = (
+            self.reshape_heads_to_batch_dim(value)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size * head_size, -1, dim // head_size)
+        )

         # TODO(PVP) - mask is currently never used. Remember to re-implement when used

         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
+            key = key.permute(0, 2, 1, 3).reshape(batch_size * head_size, -1, dim // head_size)
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
+            key = key.permute(0, 2, 3, 1).reshape(batch_size * head_size, dim // head_size, -1)
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
             else:
@@ -543,9 +552,9 @@ class CrossAttention(nn.Module):

     def _attention(self, query, key, value):
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
             query,
-            key.transpose(-1, -2),
+            key,
             beta=0,
             alpha=self.scale,
         )
@@ -568,9 +577,9 @@ class CrossAttention(nn.Module):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
             attn_slice = torch.baddbmm(
-                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                torch.empty(slice_size, query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
                 query[start_idx:end_idx],
-                key[start_idx:end_idx].transpose(-1, -2),
+                key[start_idx:end_idx],
                 beta=0,
                 alpha=self.scale,
             )
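For review purposes, here is a small standalone sanity check (not part of the patch) that exercises the layout change: it compares the old head-splitting plus `key.transpose(-1, -2)` path against the new one, where `reshape_heads_to_batch_dim` only views the projection to 4D and the key is permuted directly into `(batch * heads, head_dim, seq_len)` before `torch.baddbmm`. The tensor sizes and helper names below are illustrative assumptions, not taken from the repository.

```python
# Sanity check for the key-layout change, assuming self-attention shapes.
import torch

batch_size, seq_len, heads, dim = 2, 16, 8, 320
head_dim = dim // heads
scale = head_dim**-0.5

query = torch.randn(batch_size, seq_len, dim)
key = torch.randn(batch_size, seq_len, dim)


def old_heads_to_batch(t):
    # Old behaviour: split heads and fold them into the batch dimension.
    t = t.reshape(batch_size, seq_len, heads, head_dim)
    return t.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, head_dim)


# Old path: (B*H, S, D/H) @ (B*H, D/H, S) via an explicit transpose of the key.
q_old = old_heads_to_batch(query)
k_old = old_heads_to_batch(key)
scores_old = torch.baddbmm(
    torch.empty(q_old.shape[0], q_old.shape[1], k_old.shape[1]),
    q_old,
    k_old.transpose(-1, -2),
    beta=0,
    alpha=scale,
)

# New path: the 4D view is permuted once per tensor; the key goes straight
# to (B*H, D/H, S), so baddbmm needs no transpose.
q_new = (
    query.view(batch_size, seq_len, heads, head_dim)
    .permute(0, 2, 1, 3)
    .reshape(batch_size * heads, seq_len, head_dim)
)
k_new = (
    key.view(batch_size, seq_len, heads, head_dim)
    .permute(0, 2, 3, 1)
    .reshape(batch_size * heads, head_dim, seq_len)
)
scores_new = torch.baddbmm(
    torch.empty(q_new.shape[0], q_new.shape[1], k_new.shape[2]),
    q_new,
    k_new,
    beta=0,
    alpha=scale,
)

torch.testing.assert_close(scores_old, scores_new)
```

The same check can be extended to the sliced path by chunking `q_new` and `k_new` along dimension 0, since `_sliced_attention` only slices the batch-of-heads dimension.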