mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Revert "Save one more copy" as it's much slower on A100
This reverts commit 136f84283c.
This commit is contained in:
@@ -324,12 +324,15 @@ class AttentionBlock(nn.Module):
|
||||
attention_probs = torch.softmax(attention_scores, dim=-1, dtype=torch.float).type(attention_scores.dtype)
|
||||
|
||||
# compute attention output
|
||||
hidden_states = torch.bmm(attention_probs, value_states)
|
||||
if self.num_heads > 1:
|
||||
hidden_states = torch.empty((batch, height * width, self.num_heads, self.num_head_size), device=attention_probs.device, dtype=attention_probs.dtype)
|
||||
torch.bmm(attention_probs, value_states, out=hidden_states.view(batch * height * width, self.num_heads, self.num_head_size).permute(0, 2, 1, 3))
|
||||
hidden_states = hidden_states.view(batch, height * width, self.channels)
|
||||
else:
|
||||
hidden_states = torch.bmm(attention_probs, value_states)
|
||||
hidden_states = (
|
||||
hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
|
||||
.permute(0, 2, 1, 3)
|
||||
.contiguous()
|
||||
)
|
||||
new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
|
||||
hidden_states = hidden_states.view(new_hidden_states_shape)
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user