1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix some audio tests (#3841)

* Fix some audio tests

* make style

* fix

* make style
This commit is contained in:
Patrick von Platen
2023-06-22 13:53:27 +02:00
committed by GitHub
parent 5df2acf7d2
commit 5e3f8fff40
2 changed files with 13 additions and 4 deletions

View File

@@ -36,7 +36,7 @@ from diffusers import (
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from diffusers.utils import is_xformers_available, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
@@ -361,9 +361,15 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
@slow
# @require_torch_gpu
class AudioLDMPipelineSlowTests(unittest.TestCase):
def tearDown(self):
super().tearDown()

View File

@@ -640,7 +640,9 @@ class PipelineTesterMixin:
def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass()
def _test_xformers_attention_forwardGenerator_pass(self, test_max_difference=True, expected_max_diff=1e-4):
def _test_xformers_attention_forwardGenerator_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-4
):
if not self.test_xformers_attention:
return
@@ -660,7 +662,8 @@ class PipelineTesterMixin:
max_diff = np.abs(output_with_offload - output_without_offload).max()
self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
if test_mean_pixel_difference:
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
def test_progress_bar(self):
components = self.get_dummy_components()