Mirror of https://github.com/huggingface/diffusers.git

removing all reshapes to test perf

Author: Nouamane Tazi
Date: 2022-09-21 15:22:59 +00:00
parent c0dd0e90e8
commit 006ccb8a8c


@@ -91,7 +91,7 @@ class AttentionBlock(nn.Module):
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        hidden_states = hidden_states.reshape(batch, channel, height, width)
         # res connect and rescale
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
@@ -150,10 +150,10 @@ class SpatialTransformer(nn.Module):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = hidden_states.reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, channel, height, weight)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
@@ -262,9 +262,9 @@ class CrossAttention(nn.Module):
         key = self.to_k(context)
         value = self.to_v(context)
-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        # query = self.reshape_heads_to_batch_dim(query)
+        # key = self.reshape_heads_to_batch_dim(key)
+        # value = self.reshape_heads_to_batch_dim(value)
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
@@ -290,7 +290,7 @@ class CrossAttention(nn.Module):
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
     def _sliced_attention(self, query, key, value, sequence_length, dim):
@@ -309,7 +309,7 @@ class CrossAttention(nn.Module):
             hidden_states[start_idx:end_idx] = attn_slice
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
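
For context, the calls commented out in CrossAttention normally split the fused projection into per-head batches before attention and fold the heads back afterwards. The sketch below is not part of this commit; it only illustrates the reshape pattern implied by the helper names in the hunks above, assuming the usual shape convention (batch, seq_len, heads * dim_head) <-> (batch * heads, seq_len, dim_head).

import torch


def reshape_heads_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq_len, heads * dim_head) -> (batch * heads, seq_len, dim_head)
    batch, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch, seq_len, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch * heads, seq_len, dim // heads)


def reshape_batch_dim_to_heads(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch * heads, seq_len, dim_head) -> (batch, seq_len, heads * dim_head)
    batch_heads, seq_len, dim_head = tensor.shape
    batch = batch_heads // heads
    tensor = tensor.reshape(batch, heads, seq_len, dim_head)
    return tensor.permute(0, 2, 1, 3).reshape(batch, seq_len, heads * dim_head)

With these reshapes (and the dropped permutes in SpatialTransformer) skipped, attention runs over the full heads * dim_head inner dimension as if it were a single head, so the outputs are not numerically equivalent to the original code; the change only isolates the cost of the reshape/permute operations, matching the commit message "removing all reshapes to test perf".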