From e43244f33a137190a55c139119d4645808ee98d2 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 22 Nov 2022 00:47:22 +0100
Subject: [PATCH] Fix transpose issue

---
 src/diffusers/models/attention.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index c302b053ad..fb0674051f 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -292,7 +292,7 @@ class AttentionBlock(nn.Module):
             key_states = self.transpose_for_scores(key_proj).transpose(3,2).contiguous().view(batch * self.num_heads, self.num_head_size, height * width)
             value_states = self.transpose_for_scores(value_proj).contiguous().view(batch * self.num_heads, height * width, self.num_head_size)
         else:
-            query_states, key_states, value_states = query_proj, key_proj, value_proj
+            query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj
 
         attention_scores = torch.baddbmm(
             torch.empty(
@@ -303,7 +303,7 @@ class AttentionBlock(nn.Module):
                 device=query_states.device,
             ),
             query_states,
-            key_states.transpose(-1, -2),
+            key_states,
             beta=0,
             alpha=scale,
         )
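
Note for reviewers (not part of the patch): the sliced path already builds
key_states pre-transposed to (batch * self.num_heads, self.num_head_size,
height * width), so applying .transpose(-1, -2) again inside the baddbmm call
undid that transpose in the multi-head branch. Moving the transpose into the
else branch means both paths now hand baddbmm a key of shape
(..., head_size, seq_len). Below is a minimal standalone sketch of that shape
logic; the single-head simplification and the sizes (batch=2, seq_len=16,
dim=8) are illustrative assumptions, not values from the patched module.

import torch

batch, seq_len, dim = 2, 16, 8                    # illustrative sizes only
scale = dim ** -0.5

query_states = torch.randn(batch, seq_len, dim)   # (batch, seq_len, dim)
key_proj = torch.randn(batch, seq_len, dim)       # (batch, seq_len, dim)

# Transpose the key exactly once, up front, as the patched else branch does.
key_states = key_proj.transpose(-1, -2)           # (batch, dim, seq_len)

# With beta=0, baddbmm ignores the uninitialized first argument and returns
# alpha * (query_states @ key_states) of shape (batch, seq_len, seq_len).
attention_scores = torch.baddbmm(
    torch.empty(
        batch,
        seq_len,
        seq_len,
        dtype=query_states.dtype,
        device=query_states.device,
    ),
    query_states,
    key_states,
    beta=0,
    alpha=scale,
)
assert attention_scores.shape == (batch, seq_len, seq_len)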