diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 02454d6036..f3623c6e7e 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -91,7 +91,7 @@ class AttentionBlock(nn.Module):
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        hidden_states = hidden_states.reshape(batch, channel, height, width)
 
         # res connect and rescale
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
@@ -150,10 +150,10 @@ class SpatialTransformer(nn.Module):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = hidden_states.reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, channel, height, weight)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
 
@@ -262,9 +262,9 @@ class CrossAttention(nn.Module):
         key = self.to_k(context)
         value = self.to_v(context)
 
-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        # query = self.reshape_heads_to_batch_dim(query)
+        # key = self.reshape_heads_to_batch_dim(key)
+        # value = self.reshape_heads_to_batch_dim(value)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
@@ -290,7 +290,7 @@ class CrossAttention(nn.Module):
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
     def _sliced_attention(self, query, key, value, sequence_length, dim):
@@ -309,7 +309,7 @@ class CrossAttention(nn.Module):
            hidden_states[start_idx:end_idx] = attn_slice
 
        # reshape hidden_states
-       hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+       # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states
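
Note (not part of the patch): every hunk above touches a layout-changing op rather than a pure shape change. For reference, the reshape_heads_to_batch_dim / reshape_batch_dim_to_heads calls that the CrossAttention hunks comment out fold the attention heads into the batch dimension and back. The standalone sketch below mirrors that behaviour as free functions (the function names and the 2/16/8/64 sizes are illustrative, not taken from the patch) and checks that the pair round-trips to the identity.

import torch


def heads_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq_len, heads * head_dim) -> (batch * heads, seq_len, head_dim)
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)


def batch_dim_to_heads(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch * heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
    batch_heads, seq_len, head_dim = tensor.shape
    tensor = tensor.reshape(batch_heads // heads, heads, seq_len, head_dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_heads // heads, seq_len, head_dim * heads)


# Illustrative sizes only.
batch, seq_len, heads, head_dim = 2, 16, 8, 64
query = torch.randn(batch, seq_len, heads * head_dim)

# Folding the heads into the batch dim gives (batch * heads, seq_len, head_dim).
folded = heads_to_batch_dim(query, heads)
assert folded.shape == (batch * heads, seq_len, head_dim)

# The pair is a round trip: unfolding recovers the original tensor exactly.
assert torch.equal(batch_dim_to_heads(folded, heads), query)

# The permute before the reshape matters: a bare reshape to the same target
# shape orders the elements differently and does not round-trip in general,
# which is the same layout-vs-shape distinction behind the transpose/permute
# removals in the AttentionBlock and SpatialTransformer hunks.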