From fa4d738cbb9d33c5fc3ffe7c7f240b859a8d86e4 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Tue, 22 Nov 2022 11:53:58 +0100
Subject: [PATCH] Revert "Save one more copy" as it's much slower on A100

This reverts commit 136f84283c7f6d8ad7a20669ab1c99a50db6f83f.
---
 src/diffusers/models/attention.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 2f2a94c06f..ad64e30d1f 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -324,12 +324,15 @@ class AttentionBlock(nn.Module):
         attention_probs = torch.softmax(attention_scores, dim=-1, dtype=torch.float).type(attention_scores.dtype)

         # compute attention output
+        hidden_states = torch.bmm(attention_probs, value_states)
         if self.num_heads > 1:
-            hidden_states = torch.empty((batch, height * width, self.num_heads, self.num_head_size), device=attention_probs.device, dtype=attention_probs.dtype)
-            torch.bmm(attention_probs, value_states, out=hidden_states.view(batch * height * width, self.num_heads, self.num_head_size).permute(0, 2, 1, 3))
-            hidden_states = hidden_states.view(batch, height * width, self.channels)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
+            hidden_states = (
+                hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
+                .permute(0, 2, 1, 3)
+                .contiguous()
+            )
+            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+            hidden_states = hidden_states.view(new_hidden_states_shape)

         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
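
Not part of the patch: below is a minimal, self-contained timing sketch of the two strategies the diff swaps, for anyone who wants to reproduce the "much slower on A100" observation. It assumes a single batch element so the preallocated buffer can expose a 3-D view matching torch.bmm's output; the shapes and helper names (with_extra_copy, with_out_arg, bench) are illustrative assumptions and do not come from diffusers.

# Illustrative benchmark sketch (not part of the patch): compares the restored
# "bmm then permute().contiguous()" path against the reverted "bmm(out=...) into
# a non-contiguous view" idea. Assumes batch == 1; all names are hypothetical.
import time

import torch

num_heads, seq_len, head_size = 8, 32 * 32, 64
channels = num_heads * head_size
device = "cuda" if torch.cuda.is_available() else "cpu"

probs = torch.rand(num_heads, seq_len, seq_len, device=device)    # attention_probs
values = torch.rand(num_heads, seq_len, head_size, device=device)  # value_states


def with_extra_copy():
    # Restored path: bmm, then move heads next to head_size with one explicit copy.
    out = torch.bmm(probs, values)            # (heads, seq, head)
    out = out.permute(1, 0, 2).contiguous()   # (seq, heads, head), one copy
    return out.view(1, seq_len, channels)     # (batch=1, seq, channels)


def with_out_arg():
    # Reverted idea: write bmm's result straight into a preallocated buffer
    # through a non-contiguous view, skipping the explicit .contiguous() copy.
    # Depending on the backend this can hit a slow path, which the commit
    # message reports on A100.
    buf = torch.empty(1, seq_len, num_heads, head_size, device=device, dtype=probs.dtype)
    torch.bmm(probs, values, out=buf.view(seq_len, num_heads, head_size).permute(1, 0, 2))
    return buf.view(1, seq_len, channels)


def bench(fn, iters=50):
    # Warm up, then time; synchronize so GPU kernels are included in wall time.
    for _ in range(3):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


print("extra copy :", bench(with_extra_copy))
print("out= view  :", bench(with_out_arg))
print("same result:", torch.allclose(with_extra_copy(), with_out_arg()))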