mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[MPS] Make sure it doesn't break torch < 1.12 (#425)
* [MPS] Make sure it doesn't break torch < 1.12 * up
This commit is contained in:
committed by
GitHub
parent
8d9c4a531b
commit
f8325cfd7b
@@ -5,10 +5,15 @@ from distutils.util import strtobool
|
||||
|
||||
import torch
|
||||
|
||||
from packaging import version
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
|
||||
is_torch_higher_equal_than_1_12 = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.12")
|
||||
|
||||
if is_torch_higher_equal_than_1_12:
|
||||
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
|
||||
|
||||
|
||||
def parse_flag_from_env(key, default=False):
|
||||
|
||||
Reference in New Issue
Block a user