1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Revert "Save one more copy" as it's much slower on A100

This reverts commit 136f84283c.
This commit is contained in:
thomasw21
2022-11-22 11:53:58 +01:00
parent 136f84283c
commit fa4d738cbb

View File

@@ -324,12 +324,15 @@ class AttentionBlock(nn.Module):
         attention_probs = torch.softmax(attention_scores, dim=-1, dtype=torch.float).type(attention_scores.dtype)

         # compute attention output
-        if self.num_heads > 1:
-            hidden_states = torch.empty((batch, height * width, self.num_heads, self.num_head_size), device=attention_probs.device, dtype=attention_probs.dtype)
-            torch.bmm(attention_probs, value_states, out=hidden_states.view(batch * height * width, self.num_heads, self.num_head_size).permute(0, 2, 1, 3))
-            hidden_states = hidden_states.view(batch, height * width, self.channels)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
+        hidden_states = torch.bmm(attention_probs, value_states)
+        hidden_states = (
+            hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
+            .permute(0, 2, 1, 3)
+            .contiguous()
+        )
+        new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
+        hidden_states = hidden_states.view(new_hidden_states_shape)

         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)