Mirror of https://github.com/huggingface/diffusers.git
[Tests] Add is flaky decorator (#5139)
* add is flaky decorator
* fix more
Commit 22b19d578e (parent 787195fe20), committed by GitHub.
In the testing utilities module (`diffusers.utils.testing_utils`), the new decorator needs `functools`, `sys`, and `time`, so the import block grows accordingly:

@@ -1,3 +1,4 @@
+import functools
 import importlib
 import inspect
 import io
@@ -7,7 +8,9 @@ import os
 import random
 import re
 import struct
+import sys
 import tempfile
+import time
 import unittest
 import urllib.parse
 from contextlib import contextmanager
@@ -612,6 +615,43 @@ def pytest_terminal_summary_main(tr, id):
     config.option.tbstyle = orig_tbstyle
 
 
+# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
+def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
+    """
+    To decorate flaky tests. They will be retried on failures.
+
+    Args:
+        max_attempts (`int`, *optional*, defaults to 5):
+            The maximum number of attempts to retry the flaky test.
+        wait_before_retry (`float`, *optional*):
+            If provided, will wait that number of seconds before retrying the test.
+        description (`str`, *optional*):
+            A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
+            etc.)
+    """
+
+    def decorator(test_func_ref):
+        @functools.wraps(test_func_ref)
+        def wrapper(*args, **kwargs):
+            retry_count = 1
+
+            while retry_count < max_attempts:
+                try:
+                    return test_func_ref(*args, **kwargs)
+
+                except Exception as err:
+                    print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
+                    if wait_before_retry is not None:
+                        time.sleep(wait_before_retry)
+                    retry_count += 1
+
+            return test_func_ref(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
 # Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787
 def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
     """
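The retry semantics of the decorator above: the wrapper attempts the test up to `max_attempts - 1` times inside the loop, logging each failure to stderr and optionally sleeping `wait_before_retry` seconds between attempts, then makes one final unguarded call whose exception, if any, propagates to the test runner. Note that `description` is accepted but never referenced inside the wrapper in the version shown here. Below is a minimal sketch, not part of the commit, that exercises the decorator with a deliberately flaky toy test; the `ToyFlakyTests` class and its call counter are hypothetical and exist only for illustration.

import unittest

from diffusers.utils.testing_utils import is_flaky


class ToyFlakyTests(unittest.TestCase):
    # Hypothetical test case, not part of the diffusers test suite.
    _calls = 0

    @is_flaky(max_attempts=3, wait_before_retry=0.1)
    def test_passes_on_third_try(self):
        # Fails on the first two calls and succeeds on the third; with
        # max_attempts=3 the decorator retries twice, so the test passes.
        type(self)._calls += 1
        if type(self)._calls < 3:
            raise RuntimeError(f"transient failure #{type(self)._calls}")
        self.assertGreaterEqual(type(self)._calls, 3)


if __name__ == "__main__":
    unittest.main()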
In the video-to-video pipeline fast tests, `is_flaky` is imported and applied to two of the tests:

@@ -30,6 +30,7 @@ from diffusers.utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
+    is_flaky,
     skip_mps,
     slow,
     torch_device,
@@ -156,9 +157,14 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
+    @is_flaky()
     def test_save_load_optional_components(self):
         super().test_save_load_optional_components(expected_max_difference=0.001)
 
+    @is_flaky()
+    def test_dict_tuple_outputs_equivalent(self):
+        super().test_dict_tuple_outputs_equivalent()
+
     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
         reason="XFormers attention is only available with CUDA and `xformers` installed",
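The pattern in the hunk above, overriding an inherited shared test only to attach `@is_flaky()` and then delegating to `super()`, carries over to other flaky pipeline tests. A hedged sketch follows; `_SharedPipelineChecks` and `SomeOtherPipelineFastTests` are placeholder names standing in for a shared mixin such as `PipelineTesterMixin` and a concrete test class, and only `is_flaky` and its parameters come from this commit.

import unittest

from diffusers.utils.testing_utils import is_flaky


class _SharedPipelineChecks:
    # Placeholder for a shared mixin; in diffusers, PipelineTesterMixin's
    # version of this test compares dict and tuple pipeline outputs.
    def test_dict_tuple_outputs_equivalent(self):
        pass


class SomeOtherPipelineFastTests(_SharedPipelineChecks, unittest.TestCase):
    @is_flaky(wait_before_retry=1.0, description="intermittent failures on CI")
    def test_dict_tuple_outputs_equivalent(self):
        # Override only to add retry behavior, then delegate to the shared check.
        super().test_dict_tuple_outputs_equivalent()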