diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 9fe6a8034c..91c450d4a5 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -204,17 +204,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin): """ # 1. Input if self.is_input_continuous: - batch, channel, height, weight = hidden_states.shape + batch, channel, height, width = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) if not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) @@ -231,15 +231,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin): # 3. Output if self.is_input_continuous: if not self.use_linear_projection: - hidden_states = ( - hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() hidden_states = self.proj_out(hidden_states) else: hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual elif self.is_input_vectorized: