diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index d002cb3315..d592848757 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -171,7 +171,7 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): hidden_states = self.proj_out(hidden_states) hidden_states = ( hidden_states[None, None, :] - .reshape(batch_size, height, width, channel, num_frames) + .reshape(batch_size, height, width, num_frames, channel) .permute(0, 3, 4, 1, 2) .contiguous() )