mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -1377,3 +1377,88 @@ class Expectations(DevicePropertiesUserDict):
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.data}"
|
||||
|
||||
|
||||
def dynamic_slice_test(func):
|
||||
"""
|
||||
Decorator that injects an expected_slice parameter into a test function.
|
||||
|
||||
On the first run, it will capture the actual slice output and cache it.
|
||||
On subsequent runs, it provides the cached slice as the expected slice.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@dynamic_slice_test
|
||||
def test_stable_diffusion_ddim(self, expected_slice=None):
|
||||
# Run the pipeline
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionPipeline(**components)
|
||||
inputs = self.get_dummy_inputs("cpu")
|
||||
image = sd_pipe(**inputs).images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
# If expected_slice is provided (from cache), assert against it
|
||||
if expected_slice is not None:
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
# Always return the current slice for caching
|
||||
return image_slice
|
||||
```
|
||||
"""
|
||||
# Check if the function has the expected_slice parameter
|
||||
sig = inspect.signature(func)
|
||||
if "expected_slice" not in sig.parameters:
|
||||
raise ValueError("The decorated function must have an 'expected_slice' parameter")
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Get the test name from pytest
|
||||
# pytest sets this environment variable to the current test
|
||||
test_name = os.environ.get("PYTEST_CURRENT_TEST", "")
|
||||
if test_name:
|
||||
# Format is: test_file.py::TestClass::test_method (call)
|
||||
test_name = test_name.split(" ")[0]
|
||||
else:
|
||||
# Fallback if not running in pytest
|
||||
test_name = f"{func.__module__}.{func.__qualname__}"
|
||||
|
||||
# Create a unique filename based on hardware details
|
||||
device_props = get_device_properties()
|
||||
device_str = f"{device_props[0]}{device_props[1] if device_props[1] is not None else ''}"
|
||||
|
||||
# Setup cache directory
|
||||
cache_dir = os.environ.get("DIFFUSERS_TEST_CACHE_DIR", ".test_cache")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
cache_path = os.path.join(cache_dir, f"{test_name}_{device_str}.npy")
|
||||
|
||||
# Check for cached expected slice
|
||||
cached_slice = None
|
||||
if os.path.exists(cache_path):
|
||||
try:
|
||||
cached_slice = np.load(cache_path)
|
||||
print(f"Using cached slice from {cache_path}")
|
||||
except Exception as e:
|
||||
print(f"Error loading cached slice: {e}")
|
||||
|
||||
# Run the test function with the expected slice injected
|
||||
kwargs["expected_slice"] = cached_slice
|
||||
actual_slice = func(*args, **kwargs)
|
||||
|
||||
# If the function returned a slice and there's no cached slice yet, cache it
|
||||
if actual_slice is not None and cached_slice is None:
|
||||
# Convert torch tensor to numpy if needed
|
||||
if hasattr(actual_slice, "detach") and hasattr(actual_slice, "cpu") and hasattr(actual_slice, "numpy"):
|
||||
actual_slice_np = actual_slice.detach().cpu().numpy()
|
||||
else:
|
||||
actual_slice_np = actual_slice
|
||||
|
||||
# Save the slice
|
||||
try:
|
||||
np.save(cache_path, actual_slice_np)
|
||||
print(f"Saved slice to cache: {cache_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving slice to cache: {e}")
|
||||
|
||||
return actual_slice
|
||||
|
||||
return wrapper
|
||||
|
||||
Reference in New Issue
Block a user