Mirror of https://github.com/huggingface/diffusers.git
Remove transpose for baddbmm
@@ -497,14 +497,13 @@ class CrossAttention(nn.Module):
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.view(batch_size, seq_len, head_size, dim // head_size)
         return tensor
 
     def reshape_batch_dim_to_heads(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.view(batch_size // head_size, head_size, seq_len, dim)
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
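Note: reshape_heads_to_batch_dim now stops at the 4D (batch_size, seq_len, head_size, dim // head_size) view and leaves the final per-head layout to the caller. A minimal standalone sketch, with shapes chosen only for illustration, checking that the caller-side permute + reshape reproduces the old (batch_size * head_size, seq_len, dim // head_size) layout:

import torch

batch_size, seq_len, dim, head_size = 2, 8, 64, 4
x = torch.randn(batch_size, seq_len, dim)

# old helper behaviour: reshape to 4D, then permute + reshape inside the helper
old = (
    x.reshape(batch_size, seq_len, head_size, dim // head_size)
    .permute(0, 2, 1, 3)
    .reshape(batch_size * head_size, seq_len, dim // head_size)
)

# new helper behaviour: only the 4D view; the caller finishes the layout
new = (
    x.view(batch_size, seq_len, head_size, dim // head_size)
    .permute(0, 2, 1, 3)
    .reshape(batch_size * head_size, seq_len, dim // head_size)
)

assert torch.equal(old, new)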
@@ -518,18 +517,27 @@ class CrossAttention(nn.Module):
 
         dim = query.shape[-1]
 
-        query = self.reshape_heads_to_batch_dim(query)
+        query = (
+            self.reshape_heads_to_batch_dim(query)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size * head_size, seq_len, dim // head_size)
+        )
         key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        value = (
+            self.reshape_heads_to_batch_dim(value)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size * head_size, seq_len, dim // head_size)
+        )
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
+            key = key.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
+            key = key.permute(0, 2, 3, 1).reshape(batch_size * head_size, dim // head_size, seq_len)
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
             else:
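Note: because the helper now returns the 4D view, the key can be laid out once per branch: permute(0, 2, 1, 3) yields the (batch_size * head_size, seq_len, dim // head_size) layout the xformers path expects, while permute(0, 2, 3, 1) yields (batch_size * head_size, dim // head_size, seq_len), i.e. the key already transposed for baddbmm. A short sketch under assumed shapes showing the two layouts are transposes of each other:

import torch

batch_size, seq_len, dim, head_size = 2, 8, 64, 4
key_4d = torch.randn(batch_size, seq_len, head_size, dim // head_size)

key_for_xformers = key_4d.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
key_for_baddbmm = key_4d.permute(0, 2, 3, 1).reshape(batch_size * head_size, dim // head_size, seq_len)

# pre-permuting the key is equivalent to transposing it afterwards,
# which is exactly what the old key.transpose(-1, -2) call did
assert torch.equal(key_for_baddbmm, key_for_xformers.transpose(-1, -2))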
@@ -543,9 +551,9 @@ class CrossAttention(nn.Module):
 
     def _attention(self, query, key, value):
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
             query,
-            key.transpose(-1, -2),
+            key,
             beta=0,
             alpha=self.scale,
         )
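Note: torch.baddbmm(input, batch1, batch2, beta=0, alpha=scale) computes alpha * (batch1 @ batch2) and ignores input when beta is 0, so the torch.empty tensor only pins the output shape and dtype. With the key stored pre-transposed as (batch_size * head_size, dim // head_size, seq_len), the product query @ key ends in key.shape[2] rather than key.shape[1], which is why the empty tensor's last dimension changes. A sketch with assumed shapes and an assumed scale value:

import torch

scale = 0.125
query = torch.randn(8, 16, 40)  # (batch_size * head_size, seq_len_q, dim // head_size)
key = torch.randn(8, 40, 16)    # (batch_size * head_size, dim // head_size, seq_len_k), pre-transposed

attention_scores = torch.baddbmm(
    torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
    query,
    key,
    beta=0,
    alpha=scale,
)

assert attention_scores.shape == (8, 16, 16)
assert torch.allclose(attention_scores, scale * torch.bmm(query, key))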
@@ -568,9 +576,9 @@ class CrossAttention(nn.Module):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
             attn_slice = torch.baddbmm(
-                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                torch.empty(slice_size, query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
                 query[start_idx:end_idx],
-                key[start_idx:end_idx].transpose(-1, -2),
+                key[start_idx:end_idx],
                 beta=0,
                 alpha=self.scale,
             )
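Note: the sliced path applies the same change per slice of the batch_size * head_size dimension, so it only needs the key.shape[1] -> key.shape[2] swap and the dropped .transpose(-1, -2) on the key slice. A standalone sketch with an assumed slice_size checking that slicing does not change the result:

import torch

scale, slice_size = 0.125, 2
query = torch.randn(8, 16, 40)  # (batch_size * head_size, seq_len_q, dim // head_size)
key = torch.randn(8, 40, 16)    # (batch_size * head_size, dim // head_size, seq_len_k), pre-transposed

full = torch.baddbmm(
    torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
    query,
    key,
    beta=0,
    alpha=scale,
)

slices = []
for i in range(query.shape[0] // slice_size):
    start_idx = i * slice_size
    end_idx = (i + 1) * slice_size
    attn_slice = torch.baddbmm(
        torch.empty(slice_size, query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
        query[start_idx:end_idx],
        key[start_idx:end_idx],
        beta=0,
        alpha=scale,
    )
    slices.append(attn_slice)

assert torch.allclose(torch.cat(slices, dim=0), full)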