From e86a280c455130d597e352a6fe90367b14bfe925 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 7 Nov 2022 12:27:17 +0100
Subject: [PATCH] Remove warning about half precision on MPS (#1163)

Remove warning about half precision on MPS.
---
 src/diffusers/pipeline_utils.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 97e196e723..a708d0cfb5 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -209,13 +209,13 @@ class DiffusionPipeline(ConfigMixin):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
-                if module.dtype == torch.float16 and str(torch_device) in ["cpu", "mps"]:
+                if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
                     logger.warning(
-                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` or `mps` device. It"
-                        " is not recommended to move them to `cpu` or `mps` as running them will fail. Please make"
-                        " sure to use a `cuda` device to run the pipeline in inference. due to the lack of support for"
-                        " `float16` operations on those devices in PyTorch. Please remove the"
-                        " `torch_dtype=torch.float16` argument, or use a `cuda` device to run inference."
+                        "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
+                        " is not recommended to move them to `cpu` as running them will fail. Please make"
+                        " sure to use an accelerator to run the pipeline in inference, due to the lack of"
+                        " support for`float16` operations on this device in PyTorch. Please, remove the"
+                        " `torch_dtype=torch.float16` argument, or use another device for inference."
                     )
                 module.to(torch_device)
         return self
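
A minimal sketch of the behavior this patch affects (not part of the commit): with `mps` removed from the check, moving a half-precision pipeline to an Apple Silicon device no longer logs the float16 warning, while moving it to `cpu` still does. The checkpoint id and the availability of an MPS device below are assumptions for illustration only.

    # Hedged example, assuming diffusers with this patch applied, a PyTorch
    # build with MPS support, and the illustrative checkpoint id below.
    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe = pipe.to("mps")    # no float16 warning after this change
    # pipe = pipe.to("cpu")  # would still trigger the warning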