mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Fix transpose issue
This commit is contained in:
@@ -292,7 +292,7 @@ class AttentionBlock(nn.Module):
|
||||
key_states = self.transpose_for_scores(key_proj).transpose(3,2).contiguous().view(batch * self.num_heads, self.num_head_size, height * width)
|
||||
value_states = self.transpose_for_scores(value_proj).contiguous().view(batch * self.num_heads, height * width, self.num_head_size)
|
||||
else:
|
||||
query_states, key_states, value_states = query_proj, key_proj, value_proj
|
||||
query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj
|
||||
|
||||
attention_scores = torch.baddbmm(
|
||||
torch.empty(
|
||||
@@ -303,7 +303,7 @@ class AttentionBlock(nn.Module):
|
||||
device=query_states.device,
|
||||
),
|
||||
query_states,
|
||||
key_states.transpose(-1, -2),
|
||||
key_states,
|
||||
beta=0,
|
||||
alpha=scale,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user