From babfb8a020778acffd48c5e08968c6570f02fa1d Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 25 Nov 2022 13:59:56 +0100
Subject: [PATCH] [MPS] call contiguous after permute (#1411)

* call contiguous after permute

Fixes for MPS device

* Fix MPS UserWarning

* make style

* Revert "Fix MPS UserWarning"

This reverts commit b46c32810ee5fdc4c16a8e9224a826490b66cf49.
---
 src/diffusers/models/attention.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 4c970d062d..e9454a467a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -221,11 +221,15 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         # 3. Output
         if self.is_input_continuous:
             if not self.use_linear_projection:
-                hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+                hidden_states = (
+                    hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+                )
                 hidden_states = self.proj_out(hidden_states)
             else:
                 hidden_states = self.proj_out(hidden_states)
-                hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
+                hidden_states = (
+                    hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+                )

             output = hidden_states + residual
         elif self.is_input_vectorized:
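
Note (not part of the patch above): a minimal, self-contained sketch of why the change is needed. permute() only rearranges tensor strides and returns a non-contiguous view; on the MPS backend, feeding such a view into a convolution was a source of errors, which an explicit .contiguous() call avoids. The proj_out Conv2d below is a hypothetical stand-in for the model's projection layer, and the variable names (including "weight", which the patched code uses where "width" is meant) mirror the diff.

    import torch

    # Fall back to CPU so the sketch also runs on machines without MPS.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    batch, height, weight, inner_dim = 1, 8, 8, 4
    hidden_states = torch.randn(batch, height * weight, inner_dim, device=device)

    # Mirrors the patched code path: reshape to NHWC, then permute to NCHW.
    x = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
    print(x.is_contiguous())  # False: permute changes strides, not memory layout

    # Materializing the permuted layout is what the patch adds.
    x = x.contiguous()
    print(x.is_contiguous())  # True

    # Stand-in for self.proj_out; safe to apply once the tensor is contiguous.
    proj_out = torch.nn.Conv2d(inner_dim, inner_dim, kernel_size=1).to(device)
    out = proj_out(x)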