1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

mps test fixes (#2470)

* Skip variant tests (UNet1d, UNetRL) on mps.

mish op not yet supported.

* Exclude a couple of panorama tests on mps

They are too slow for fast CI.

* Exclude mps panorama from more tests.

* mps: exclude all fast panorama tests as they keep failing.
This commit is contained in:
Pedro Cuenca
2023-02-24 15:19:53 +01:00
committed by GitHub
parent 589faa8c88
commit 54bc882d96
2 changed files with 10 additions and 1 deletions

View File

@@ -66,6 +66,10 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
@@ -186,6 +190,10 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_pretrained(self):
super().test_model_from_pretrained()

View File

@@ -30,7 +30,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from ...test_pipelines_common import PipelineTesterMixin
@@ -38,6 +38,7 @@ from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
@skip_mps
class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableDiffusionPanoramaPipeline