From f8325cfd7ba0f024d590bd466a994099467dcc13 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 8 Sep 2022 16:22:23 +0200 Subject: [PATCH] [MPS] Make sure it doesn't break torch < 1.12 (#425) * [MPS] Make sure it doesn't break torch < 1.12 * up --- src/diffusers/testing_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffusers/testing_utils.py b/src/diffusers/testing_utils.py index a1288b4edb..ff8b6aa9b4 100644 --- a/src/diffusers/testing_utils.py +++ b/src/diffusers/testing_utils.py @@ -5,10 +5,15 @@ from distutils.util import strtobool import torch +from packaging import version + global_rng = random.Random() torch_device = "cuda" if torch.cuda.is_available() else "cpu" -torch_device = "mps" if torch.backends.mps.is_available() else torch_device +is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12") + +if is_torch_higher_equal_than_1_12: + torch_device = "mps" if torch.backends.mps.is_available() else torch_device def parse_flag_from_env(key, default=False):