1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Transformer2DModel] don't norm twice (#1381)

don't norm twice
This commit is contained in:
Suraj Patil
2022-11-24 00:12:45 +01:00
committed by GitHub
parent f07a16e09b
commit 1524122532

View File

@@ -201,13 +201,11 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
else:
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)