mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[MPS] call contiguous after permute (#1411)
* call contiguous after permute
Fixes for MPS device
* Fix MPS UserWarning
* make style
* Revert "Fix MPS UserWarning"
This reverts commit b46c32810e.
This commit is contained in:
@@ -221,11 +221,15 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
||||
hidden_states = (
|
||||
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
else:
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
||||
hidden_states = (
|
||||
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
|
||||
Reference in New Issue
Block a user