1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Do not use torch.long in mps (#1488)

* Do not use torch.long in mps

Addresses #1056.

* Use torch.int instead of float.

* Propagate changes.

* Do not silently change float -> int.

* Propagate changes.

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
Pedro Cuenca
2022-12-02 13:10:17 +01:00
committed by GitHub
parent a816a87a09
commit 3ceaa280bd
2 changed files with 16 additions and 4 deletions

View File

@@ -299,8 +299,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         timesteps = timestep
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
-            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
-        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if torch.is_floating_point(timesteps):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML

View File

@@ -377,8 +377,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         timesteps = timestep
         if not torch.is_tensor(timesteps):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
-            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
-        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            # This would be a good case for the `match` statement (Python 3.10+)
+            is_mps = sample.device.type == "mps"
+            if torch.is_floating_point(timesteps):
+                dtype = torch.float32 if is_mps else torch.float64
+            else:
+                dtype = torch.int32 if is_mps else torch.int64
+            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+        elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML