1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Fix transpose issue

This commit is contained in:
thomasw21
2022-11-22 00:47:22 +01:00
parent 3c45926a0e
commit e43244f33a

View File

@@ -292,7 +292,7 @@ class AttentionBlock(nn.Module):
key_states = self.transpose_for_scores(key_proj).transpose(3,2).contiguous().view(batch * self.num_heads, self.num_head_size, height * width)
value_states = self.transpose_for_scores(value_proj).contiguous().view(batch * self.num_heads, height * width, self.num_head_size)
else:
query_states, key_states, value_states = query_proj, key_proj, value_proj
query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj
attention_scores = torch.baddbmm(
torch.empty(
@@ -303,7 +303,7 @@ class AttentionBlock(nn.Module):
device=query_states.device,
),
query_states,
key_states.transpose(-1, -2),
key_states,
beta=0,
alpha=scale,
)