Mirror of https://github.com/huggingface/diffusers.git
use torch.matmul instead of einsum in attention. (#445)
* use torch.matmul instead of einsum
* fix softmax
@@ -275,11 +275,9 @@ class CrossAttention(nn.Module):
 for i in range(hidden_states.shape[0] // slice_size):
     start_idx = i * slice_size
     end_idx = (i + 1) * slice_size
-    attn_slice = (
-        torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
-    )
+    attn_slice = torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
     attn_slice = attn_slice.softmax(dim=-1)
-    attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+    attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])

     hidden_states[start_idx:end_idx] = attn_slice
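For reference, a minimal standalone sketch (not part of the commit) checking that the batched torch.matmul form is numerically equivalent to the einsum calls it replaces. The tensor names, shapes, and scale below are illustrative assumptions, not values from the diffusers code.

# Hypothetical shapes: (batch*heads, seq_len, head_dim)
import torch

batch_heads, seq_len, head_dim = 2, 4, 8
scale = head_dim ** -0.5
query = torch.randn(batch_heads, seq_len, head_dim)
key = torch.randn(batch_heads, seq_len, head_dim)
value = torch.randn(batch_heads, seq_len, head_dim)

# Old path: einsum contracts the shared head dimension d.
scores_einsum = torch.einsum("b i d, b j d -> b i j", query, key) * scale
out_einsum = torch.einsum("b i j, b j d -> b i d", scores_einsum.softmax(dim=-1), value)

# New path: plain batched matmul with an explicit transpose of key.
scores_matmul = torch.matmul(query, key.transpose(1, 2)) * scale
out_matmul = torch.matmul(scores_matmul.softmax(dim=-1), value)

assert torch.allclose(out_einsum, out_matmul, atol=1e-6)

Both paths compute the same scaled dot-product attention; the matmul form simply spells out the transpose instead of encoding it in the einsum subscripts.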