1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[tests] mark the wanvace lora tester flaky (#11883)

* mark wanvace lora tests as flaky

* ability to apply is_flaky at a class-level

* update

* increase max_attempt.

* increase attemtp.
This commit is contained in:
Sayak Paul
2025-07-09 13:27:15 +05:30
committed by GitHub
parent 737d7fc3b0
commit cc1f9a2ce3
2 changed files with 21 additions and 10 deletions

View File

@@ -994,10 +994,10 @@ def pytest_terminal_summary_main(tr, id):
config.option.tbstyle = orig_tbstyle
# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
# Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
"""
To decorate flaky tests. They will be retried on failures.
To decorate flaky tests (methods or entire classes). They will be retried on failures.
Args:
max_attempts (`int`, *optional*, defaults to 5):
@@ -1009,22 +1009,33 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
etc.)
"""
def decorator(test_func_ref):
@functools.wraps(test_func_ref)
def decorator(obj):
# If decorating a class, wrap each test method on it
if inspect.isclass(obj):
for attr_name, attr_value in list(obj.__dict__.items()):
if callable(attr_value) and attr_name.startswith("test"):
# recursively decorate the method
setattr(obj, attr_name, decorator(attr_value))
return obj
# Otherwise we're decorating a single test function / method
@functools.wraps(obj)
def wrapper(*args, **kwargs):
retry_count = 1
while retry_count < max_attempts:
try:
return test_func_ref(*args, **kwargs)
return obj(*args, **kwargs)
except Exception as err:
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
msg = (
f"[FLAKY] {description or obj.__name__!r} "
f"failed on attempt {retry_count}/{max_attempts}: {err}"
)
print(msg, file=sys.stderr)
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1
return test_func_ref(*args, **kwargs)
return obj(*args, **kwargs)
return wrapper

View File

@@ -46,6 +46,7 @@ from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
@@ -217,6 +218,5 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"Lora outputs should match.",
)
@is_flaky
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
super().test_simple_inference_with_text_denoiser_lora_and_scale()