1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Remove transpose for baddbmm

This commit is contained in:
thomasw21
2022-11-22 12:17:11 +01:00
parent 31d26872c1
commit 5d4145cfa2

View File

@@ -497,14 +497,13 @@ class CrossAttention(nn.Module):
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
tensor = tensor.view(batch_size, seq_len, head_size, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.view(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
@@ -518,18 +517,27 @@ class CrossAttention(nn.Module):
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
query = (
self.reshape_heads_to_batch_dim(query)
.permute(0, 2, 1, 3)
.reshape(batch_size * head_size, seq_len, dim // head_size)
)
value = (
self.reshape_heads_to_batch_dim(value)
.permute(0, 2, 1, 3)
.reshape(batch_size * head_size, seq_len, dim // head_size)
)
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
key = key.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
key = key.permute(0, 2, 3, 1).reshape(batch_size * head_size, dim // head_size, seq_len)
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
else:
@@ -543,9 +551,9 @@ class CrossAttention(nn.Module):
def _attention(self, query, key, value):
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
key,
beta=0,
alpha=self.scale,
)
@@ -568,9 +576,9 @@ class CrossAttention(nn.Module):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
torch.empty(slice_size, query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
query[start_idx:end_idx],
key[start_idx:end_idx].transpose(-1, -2),
key[start_idx:end_idx],
beta=0,
alpha=self.scale,
)