mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
mps test fixes (#2470)
* Skip variant tests (UNet1d, UNetRL) on mps. mish op not yet supported. * Exclude a couple of panorama tests on mps They are too slow for fast CI. * Exclude mps panorama from more tests. * mps: exclude all fast panorama tests as they keep failing.
This commit is contained in:
@@ -66,6 +66,10 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
@@ -186,6 +190,10 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
@@ -30,7 +30,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
@@ -38,6 +38,7 @@ from ...test_pipelines_common import PipelineTesterMixin
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
@skip_mps
|
||||
class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = StableDiffusionPanoramaPipeline
|
||||
|
||||
|
||||
Reference in New Issue
Block a user