1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[MPS] call contiguous after permute (#1411)

* call contiguous after permute

Fixes for MPS device

* Fix MPS UserWarning

* make style

* Revert "Fix MPS UserWarning"

This reverts commit b46c32810e.
This commit is contained in:
Kashif Rasul
2022-11-25 13:59:56 +01:00
committed by GitHub
parent 35099b207e
commit babfb8a020

View File

@@ -221,11 +221,15 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
output = hidden_states + residual
elif self.is_input_vectorized: