From 3ceaa280bd1550bc17cd8268cc34278b7f0b9070 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 2 Dec 2022 13:10:17 +0100 Subject: [PATCH] Do not use torch.long in mps (#1488) * Do not use torch.long in mps Addresses #1056. * Use torch.int instead of float. * Propagate changes. * Do not silently change float -> int. * Propagate changes. * Apply suggestions from code review Co-authored-by: Anton Lozhkov Co-authored-by: Anton Lozhkov --- src/diffusers/models/unet_2d_condition.py | 10 ++++++++-- .../versatile_diffusion/modeling_text_unet.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 43f032729b..4bae7b9fa6 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -299,8 +299,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if torch.is_floating_point(timesteps): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index e782274d82..7637339a84 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -377,8 +377,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if torch.is_floating_point(timesteps): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML