mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix some audio tests (#3841)
* Fix some audio tests * make style * fix * make style
This commit is contained in:
committed by
GitHub
parent
5df2acf7d2
commit
5e3f8fff40
@@ -36,7 +36,7 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import slow, torch_device
|
||||
from diffusers.utils import is_xformers_available, slow, torch_device
|
||||
from diffusers.utils.testing_utils import enable_full_determinism
|
||||
|
||||
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
|
||||
@@ -361,9 +361,15 @@ class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(test_mean_pixel_difference=False)
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
|
||||
|
||||
|
||||
@slow
|
||||
# @require_torch_gpu
|
||||
class AudioLDMPipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
@@ -640,7 +640,9 @@ class PipelineTesterMixin:
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
self._test_xformers_attention_forwardGenerator_pass()
|
||||
|
||||
def _test_xformers_attention_forwardGenerator_pass(self, test_max_difference=True, expected_max_diff=1e-4):
|
||||
def _test_xformers_attention_forwardGenerator_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-4
|
||||
):
|
||||
if not self.test_xformers_attention:
|
||||
return
|
||||
|
||||
@@ -660,7 +662,8 @@ class PipelineTesterMixin:
|
||||
max_diff = np.abs(output_with_offload - output_without_offload).max()
|
||||
self.assertLess(max_diff, expected_max_diff, "XFormers attention should not affect the inference results")
|
||||
|
||||
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
|
||||
if test_mean_pixel_difference:
|
||||
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
|
||||
|
||||
def test_progress_bar(self):
|
||||
components = self.get_dummy_components()
|
||||
|
||||
Reference in New Issue
Block a user