diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py
index 8025bdd564..986770bede 100644
--- a/tests/test_pipelines_common.py
+++ b/tests/test_pipelines_common.py
@@ -450,7 +450,9 @@ class PipelineTesterMixin:
     def test_attention_slicing_forward_pass(self):
         self._test_attention_slicing_forward_pass()
 
-    def _test_attention_slicing_forward_pass(self, test_max_difference=True, expected_max_diff=1e-3):
+    def _test_attention_slicing_forward_pass(
+        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+    ):
         if not self.test_attention_slicing:
             return
 
@@ -474,7 +476,8 @@ class PipelineTesterMixin:
         max_diff = np.abs(output_with_slicing - output_without_slicing).max()
         self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
 
-        assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
+        if test_mean_pixel_difference:
+            assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
 
     @unittest.skipIf(
         torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
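
The diff adds a test_mean_pixel_difference flag so that individual pipeline test classes can skip the pixel-level comparison while keeping the numeric max-difference check. Below is a minimal sketch (not part of the PR) of how a pipeline-specific test class might use the new flag; the class name MyPipelineFastTests and the import path are hypothetical placeholders, and the real diffusers test classes define additional required attributes.

import unittest

from test_pipelines_common import PipelineTesterMixin  # assumed import path within tests/


class MyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    # Hypothetical pipeline test class; real subclasses also set pipeline_class,
    # get_dummy_components(), get_dummy_inputs(), etc.

    def test_attention_slicing_forward_pass(self):
        # For pipelines whose output is not a decoded image (e.g. latents),
        # skip assert_mean_pixel_difference and rely only on the max-diff assertion.
        self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)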