From 8a4c3e50bdae402482f49cc72c4f97e46ac083ee Mon Sep 17 00:00:00 2001
From: William Held
Date: Tue, 27 Dec 2022 09:09:21 -0500
Subject: [PATCH] Width was typod as weight (#1800)

* Width was typod as weight

* Run Black
---
 src/diffusers/models/attention.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 9fe6a8034c..91c450d4a5 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -204,17 +204,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         """
         # 1. Input
         if self.is_input_continuous:
-            batch, channel, height, weight = hidden_states.shape
+            batch, channel, height, width = hidden_states.shape
             residual = hidden_states
 
             hidden_states = self.norm(hidden_states)
             if not self.use_linear_projection:
                 hidden_states = self.proj_in(hidden_states)
                 inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
             else:
                 inner_dim = hidden_states.shape[1]
-                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
                 hidden_states = self.proj_in(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
@@ -231,15 +231,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         # 3. Output
         if self.is_input_continuous:
             if not self.use_linear_projection:
-                hidden_states = (
-                    hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
-                )
+                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
                 hidden_states = self.proj_out(hidden_states)
             else:
                 hidden_states = self.proj_out(hidden_states)
-                hidden_states = (
-                    hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
-                )
+                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
 
             output = hidden_states + residual
         elif self.is_input_vectorized:
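
For context on what these lines do, here is a minimal sketch (not part of the patch; the toy shapes and the `torch` import are mine) of the NCHW flatten/restore round trip that `Transformer2DModel` performs around its attention blocks:

```python
import torch

# Toy NCHW tensor; shapes here are illustrative, not taken from the patch.
batch, channel, height, width = 2, 320, 64, 48  # non-square, so H and W differ
hidden_states = torch.randn(batch, channel, height, width)

# Input path: flatten spatial dims for attention, (B, C, H, W) -> (B, H*W, C),
# mirroring the permute/reshape lines touched by this patch.
tokens = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)
assert tokens.shape == (batch, height * width, channel)

# Output path: restore (B, H*W, C) -> (B, C, H, W).
restored = tokens.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
assert torch.equal(restored, hidden_states)
```

Note that the rename itself is purely cosmetic: the old `weight` variable always held the width dimension unpacked from `hidden_states.shape`, so behavior is unchanged. The one-character-shorter name also lets the output-path reshape fit on a single line under the repo's Black line-length limit, which is why the follow-up "Run Black" commit collapses the previously wrapped expressions.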