mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[PipelineTesterMixin] Handle non-image outputs for attn slicing test (#2504)
* [PipelineTesterMixin] Handle non-image outputs for batch/sinle inference test * style --------- Co-authored-by: William Berman <WLBberman@gmail.com>
This commit is contained in:
@@ -450,7 +450,9 @@ class PipelineTesterMixin:
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
self._test_attention_slicing_forward_pass()
|
||||
|
||||
def _test_attention_slicing_forward_pass(self, test_max_difference=True, expected_max_diff=1e-3):
|
||||
def _test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
@@ -474,7 +476,8 @@ class PipelineTesterMixin:
|
||||
max_diff = np.abs(output_with_slicing - output_without_slicing).max()
|
||||
self.assertLess(max_diff, expected_max_diff, "Attention slicing should not affect the inference results")
|
||||
|
||||
assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
|
||||
if test_mean_pixel_difference:
|
||||
assert_mean_pixel_difference(output_with_slicing[0], output_without_slicing[0])
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"),
|
||||
|
||||
Reference in New Issue
Block a user