1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[MPS] Make sure it doesn't break torch < 1.12 (#425)

* [MPS] Make sure it doesn't break torch < 1.12

* up
This commit is contained in:
Patrick von Platen
2022-09-08 16:22:23 +02:00
committed by GitHub
parent 8d9c4a531b
commit f8325cfd7b

View File

@@ -5,10 +5,15 @@ from distutils.util import strtobool
import torch
from packaging import version
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
if is_torch_higher_equal_than_1_12:
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
def parse_flag_from_env(key, default=False):