diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 332f8e10a3..ff4f4002d2 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -137,11 +137,12 @@ class PeftLoraLoaderMixinTests:
         # Get or create the cache on the class (not instance)
         if not hasattr(type(self), "cached_base_pipe_outs"):
             setattr(type(self), "cached_base_pipe_outs", {})
-
+
         cached_base_pipe_outs = type(self).cached_base_pipe_outs
-
+
         all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
         if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names):
+            __import__("ipdb").set_trace()
             return
 
         for scheduler_cls in self.scheduler_classes:
@@ -158,12 +159,15 @@ class PeftLoraLoaderMixinTests:
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             cached_base_pipe_outs[scheduler_cls.__name__] = output_no_lora
-
+
         # Update the class attribute
         setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
 
     def get_base_pipeline_output(self, scheduler_cls):
-        self._cache_base_pipeline_output()
+        """
+        Returns the cached base pipeline output for the given scheduler.
+        Cache is populated during setUp, so this just retrieves the value.
+        """
         return type(self).cached_base_pipe_outs[scheduler_cls.__name__]
 
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
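
The diff above caches the expensive no-LoRA baseline outputs on the test class, keyed by scheduler class name, so they are computed once per test class rather than once per test; `get_base_pipeline_output` then becomes a plain lookup. Below is a minimal, self-contained sketch of that class-level caching pattern under assumptions: the `_SchedulerA`/`_SchedulerB` classes, `_run_base_pipeline`, and the example test are hypothetical stand-ins for the real scheduler classes and pipeline call in `tests/lora/utils.py`.

```python
import unittest


class _SchedulerA:  # hypothetical stand-in for a real scheduler class
    pass


class _SchedulerB:  # hypothetical stand-in for a real scheduler class
    pass


class CachedBaseOutputTests(unittest.TestCase):
    scheduler_classes = [_SchedulerA, _SchedulerB]

    def setUp(self):
        # Populate the shared cache before each test; after the first test
        # this is a cheap membership check and an early return.
        self._cache_base_pipeline_output()

    def _run_base_pipeline(self, scheduler_cls):
        # Placeholder for the expensive pipe(**inputs) call in the real tests.
        return f"baseline-for-{scheduler_cls.__name__}"

    def _cache_base_pipeline_output(self):
        # Keep the cache on the class (not the instance) so every test reuses it.
        if not hasattr(type(self), "cached_base_pipe_outs"):
            type(self).cached_base_pipe_outs = {}
        cache = type(self).cached_base_pipe_outs

        # Nothing to do if every scheduler already has a cached baseline.
        names = [cls.__name__ for cls in self.scheduler_classes]
        if cache and all(name in cache for name in names):
            return

        for scheduler_cls in self.scheduler_classes:
            cache[scheduler_cls.__name__] = self._run_base_pipeline(scheduler_cls)

    def get_base_pipeline_output(self, scheduler_cls):
        # Cache is populated in setUp, so this is just a dictionary lookup.
        return type(self).cached_base_pipe_outs[scheduler_cls.__name__]

    def test_lookup_hits_cache(self):
        out = self.get_base_pipeline_output(_SchedulerA)
        self.assertEqual(out, "baseline-for-_SchedulerA")


if __name__ == "__main__":
    unittest.main()
```

Storing the cache on `type(self)` rather than `self` is what lets the baselines survive across test methods, since unittest constructs a fresh instance for each test.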