From 2b4f849db9cfd73c5c367b2ac124c8e48ef32430 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Mon, 6 Mar 2023 00:36:47 +0100 Subject: [PATCH] [PipelineTesterMixin] Handle non-image outputs for attn slicing test (#2504) * [PipelineTesterMixin] Handle non-image outputs for batch/sinle inference test * style --------- Co-authored-by: William Berman --- tests/test_pipelines_common.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 8025bdd564..986770bede 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -450,7 +450,9 @@ class PipelineTesterMixin: def test_attention_slicing_forward_pass(self): self._test_attention_slicing_forward_pass() - def _test_attention_slicing_forward_pass(self, test_max_difference=True, expected_max_diff=1e-3): + def _test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): if not self.test_attention_slicing: return @@ -474,7 +476,8 @@ class PipelineTesterMixin: max_diff = np.abs(output_with_slicing - output_without_slicing).max() self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results") - assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0]) + if test_mean_pixel_difference: + assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0]) @unittest.skipIf( torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),