diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index c302b053ad..fb0674051f 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -292,7 +292,7 @@ class AttentionBlock(nn.Module): key_states = self.transpose_for_scores(key_proj).transpose(3,2).contiguous().view(batch * self.num_heads, self.num_head_size, height * width) value_states = self.transpose_for_scores(value_proj).contiguous().view(batch * self.num_heads, height * width, self.num_head_size) else: - query_states, key_states, value_states = query_proj, key_proj, value_proj + query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj attention_scores = torch.baddbmm( torch.empty( @@ -303,7 +303,7 @@ class AttentionBlock(nn.Module): device=query_states.device, ), query_states, - key_states.transpose(-1, -2), + key_states, beta=0, alpha=scale, )