diff --git a/tests/lora/utils.py b/tests/lora/utils.py index c4a7b433bd..332f8e10a3 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -129,20 +129,21 @@ class PeftLoraLoaderMixinTests: text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - cached_base_pipe_outs = {} - def setUp(self): - self.get_base_pipe_outs() + self._cache_base_pipeline_output() super().setUp() - def get_base_pipe_outs(self): - cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {}) + def _cache_base_pipeline_output(self): + # Get or create the cache on the class (not instance) + if not hasattr(type(self), "cached_base_pipe_outs"): + setattr(type(self), "cached_base_pipe_outs", {}) + + cached_base_pipe_outs = type(self).cached_base_pipe_outs + all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes] - # Check if all required schedulers are already cached if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names): return - cached_base_pipe_outs = cached_base_pipe_outs or {} for scheduler_cls in self.scheduler_classes: if scheduler_cls.__name__ in cached_base_pipe_outs: continue @@ -156,21 +157,14 @@ class PeftLoraLoaderMixinTests: # explicitly. _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora}) - - setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs) - - def get_base_pipeline_output(self, scheduler_cls): - """ - Returns the cached base pipeline output for the given scheduler. - Properly handles accessing the class-level cache. - Ensures cache is populated if it hasn't been already. - """ - # Ensure cache is populated - self.get_base_pipe_outs() + cached_base_pipe_outs[scheduler_cls.__name__] = output_no_lora - cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {}) - return cached_base_pipe_outs[scheduler_cls.__name__] + # Update the class attribute + setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs) + + def get_base_pipeline_output(self, scheduler_cls): + self._cache_base_pipeline_output() + return type(self).cached_base_pipe_outs[scheduler_cls.__name__] def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: