1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[PipelineTesterMixin] Handle non-image outputs for attn slicing test (#2504)

* [PipelineTesterMixin] Handle non-image outputs for batch/single inference test

* style

---------

Co-authored-by: William Berman <WLBberman@gmail.com>
This commit is contained in:
Sanchit Gandhi
2023-03-06 00:36:47 +01:00
committed by GitHub
parent e4c356d3f6
commit 2b4f849db9

View File

@@ -450,7 +450,9 @@ class PipelineTesterMixin:
def test_attention_slicing_forward_pass(self):
self._test_attention_slicing_forward_pass()
def _test_attention_slicing_forward_pass(self, test_max_difference=True, expected_max_diff=1e-3):
def _test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
@@ -474,7 +476,8 @@ class PipelineTesterMixin:
max_diff = np.abs(output_with_slicing - output_without_slicing).max()
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
if test_mean_pixel_difference:
assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
@unittest.skipIf(
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),