diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 7eaee021f5..00831e3888 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -132,7 +132,7 @@ class PeftLoraLoaderMixinTests: cached_non_lora_outputs = {} @pytest.fixture(scope="class", autouse=True) - def cache_non_lora_outputs(self, request): + def cache_non_lora_outputs(self): """ This fixture will be executed once per test class and will populate the cached_non_lora_outputs dictionary. @@ -150,9 +150,13 @@ class PeftLoraLoaderMixinTests: # Always ensure the inputs are without the `generator`. Make sure to pass the `generator` # explicitly. _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora + # Ensures that there's no inconsistency when reusing the cache. + yield + self.cached_non_lora_outputs.clear() + def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")