mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Woops
This commit is contained in:
@@ -520,12 +520,12 @@ class CrossAttention(nn.Module):
|
||||
query = (
|
||||
self.reshape_heads_to_batch_dim(query)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
.reshape(batch_size * self.heads, seq_len, dim // self.heads)
|
||||
)
|
||||
value = (
|
||||
self.reshape_heads_to_batch_dim(value)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(batch_size * head_size, seq_len, dim // head_size)
|
||||
.reshape(batch_size * self.heads, seq_len, dim // self.heads)
|
||||
)
|
||||
|
||||
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
|
||||
|
||||
Reference in New Issue
Block a user