diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 2f2a94c06f..ad64e30d1f 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -324,12 +324,15 @@ class AttentionBlock(nn.Module):
         attention_probs = torch.softmax(attention_scores, dim=-1, dtype=torch.float).type(attention_scores.dtype)
 
         # compute attention output
+        hidden_states = torch.bmm(attention_probs, value_states)
         if self.num_heads > 1:
-            hidden_states = torch.empty((batch, height * width, self.num_heads, self.num_head_size), device=attention_probs.device, dtype=attention_probs.dtype)
-            torch.bmm(attention_probs, value_states, out=hidden_states.view(batch * height * width, self.num_heads, self.num_head_size).permute(0, 2, 1, 3))
-            hidden_states = hidden_states.view(batch, height * width, self.channels)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
+            hidden_states = (
+                hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
+                .permute(0, 2, 1, 3)
+                .contiguous()
+            )
+            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+            hidden_states = hidden_states.view(new_hidden_states_shape)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
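
Review note (not part of the patch): a minimal standalone sketch of the head-merge pattern the added lines use, with made-up shapes; the variable names and sizes below are illustrative only, not taken from the diffusers source.

import torch

batch, num_heads, seq_len, head_size = 2, 4, 16, 8
channels = num_heads * head_size

# bmm runs over the fused (batch * num_heads) dimension, as in the patched code
attention_probs = torch.rand(batch * num_heads, seq_len, seq_len)
value_states = torch.rand(batch * num_heads, seq_len, head_size)

hidden_states = torch.bmm(attention_probs, value_states)

# split the fused batch/head dim, move heads next to head_size, then flatten heads into channels
hidden_states = (
    hidden_states.view(batch, num_heads, seq_len, head_size)
    .permute(0, 2, 1, 3)
    .contiguous()
)
new_shape = hidden_states.size()[:-2] + (channels,)
hidden_states = hidden_states.view(new_shape)

assert hidden_states.shape == (batch, seq_len, channels)

The .contiguous() call is what makes the final view legal after the permute, replacing the removed out= variant that tried to write bmm results through a non-contiguous strided view.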