Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
removing all reshapes to test perf
@@ -91,7 +91,7 @@ class AttentionBlock(nn.Module):

         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        hidden_states = hidden_states.reshape(batch, channel, height, width)

         # res connect and rescale
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
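For context on the AttentionBlock hunk: with the transpose dropped, the reshape reinterprets the (batch, height * width, channel) activations in their original memory order instead of moving channels in front first. Both lines yield a tensor of shape (batch, channel, height, width), but not the same values, which is presumably acceptable for a pure throughput experiment. A minimal sketch, with the layout assumed from the surrounding code:

import torch

batch, channel, height, width = 2, 4, 8, 8
# assumed layout after self.proj_attn: (batch, height * width, channel)
hidden_states = torch.randn(batch, height * width, channel)

with_transpose = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
without_transpose = hidden_states.reshape(batch, channel, height, width)

print(with_transpose.shape == without_transpose.shape)  # True: both (2, 4, 8, 8)
print(torch.equal(with_transpose, without_transpose))   # False (in general): element order differs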
@@ -150,10 +150,10 @@ class SpatialTransformer(nn.Module):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = hidden_states.reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, channel, height, weight)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual

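The SpatialTransformer hunk is analogous: without the permutes, the flatten to (batch, height * weight, channel) keeps NCHW memory order, so each token fed to the transformer blocks is a contiguous slice of the feature map rather than a per-pixel channel vector, while the reshape-only flatten/unflatten pair still round-trips the layout. A rough sketch, assuming the usual NCHW input (the source's own `weight` name is kept; it plays the role of the width):

import torch

batch, channel, height, weight = 2, 4, 8, 8
x = torch.randn(batch, channel, height, weight)

# original: move channels last, then flatten the spatial dims into tokens
tokens_channels_last = x.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
# this commit: flatten directly in NCHW memory order
tokens_nchw_order = x.reshape(batch, height * weight, channel)

print(torch.equal(tokens_channels_last, tokens_nchw_order))             # False (in general)
print(torch.equal(tokens_nchw_order.reshape(batch, channel, height, weight), x))  # True: plain reshapes round-trip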
@@ -262,9 +262,9 @@ class CrossAttention(nn.Module):
         key = self.to_k(context)
         value = self.to_v(context)

-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        # query = self.reshape_heads_to_batch_dim(query)
+        # key = self.reshape_heads_to_batch_dim(key)
+        # value = self.reshape_heads_to_batch_dim(value)

         # TODO(PVP) - mask is currently never used. Remember to re-implement when used

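For reference, `reshape_heads_to_batch_dim` is the helper being bypassed here: it folds the head dimension into the batch dimension before the attention matmuls, so skipping it leaves query, key, and value in their (batch, seq_len, heads * dim_head) projections. A standalone sketch of that layout change under the usual multi-head convention (the head and dim_head sizes below are illustrative, not taken from the diff):

import torch

def reshape_heads_to_batch_dim(tensor, heads):
    # (batch, seq_len, heads * dim_head) -> (batch * heads, seq_len, dim_head)
    batch, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch, seq_len, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch * heads, seq_len, dim // heads)

query = torch.randn(2, 16, 8 * 40)  # (batch, seq_len, heads * dim_head)
print(reshape_heads_to_batch_dim(query, heads=8).shape)  # torch.Size([16, 16, 40])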
@@ -290,7 +290,7 @@ class CrossAttention(nn.Module):
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states

     def _sliced_attention(self, query, key, value, sequence_length, dim):
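`reshape_batch_dim_to_heads` is the inverse mapping, normally applied to the attention output before the output projection; with the head-to-batch reshapes above also commented out, the tensor never leaves its (batch, seq_len, heads * dim_head) layout, so skipping the inverse keeps the shapes consistent. A matching sketch, same illustrative sizes as above:

import torch

def reshape_batch_dim_to_heads(tensor, heads):
    # (batch * heads, seq_len, dim_head) -> (batch, seq_len, heads * dim_head)
    batch_heads, seq_len, dim_head = tensor.shape
    batch = batch_heads // heads
    tensor = tensor.reshape(batch, heads, seq_len, dim_head)
    return tensor.permute(0, 2, 1, 3).reshape(batch, seq_len, heads * dim_head)

out = torch.randn(16, 16, 40)  # (batch * heads, seq_len, dim_head)
print(reshape_batch_dim_to_heads(out, heads=8).shape)  # torch.Size([2, 16, 320])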
@@ -309,7 +309,7 @@ class CrossAttention(nn.Module):
            hidden_states[start_idx:end_idx] = attn_slice

        # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states