diff --git a/tests/lora/utils.py b/tests/lora/utils.py index c9ea203b52..0fb266e918 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -129,17 +129,24 @@ class PeftLoraLoaderMixinTests: text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] - cached_non_lora_outputs = {} + cached_base_pipe_outputs = {} + + def setUp(self): + super().setUp() + self.get_base_pipeline_output() + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls.cached_base_pipe_outputs.clear() + + def get_base_pipeline_output(self): + scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes] + if self.cached_base_pipe_outputs and all(k in self.cached_base_pipe_outputs for k in scheduler_names): + return - @pytest.fixture(scope="class", autouse=True) - def cache_non_lora_outputs(self): - """ - This fixture will be executed once per test class and will populate - the cached_non_lora_outputs dictionary. - """ for scheduler_cls in self.scheduler_classes: - # Check if the output for this scheduler is already cached to avoid re-running - if scheduler_cls.__name__ in self.cached_non_lora_outputs: + if scheduler_cls.__name__ in self.cached_base_pipe_outputs: continue components, _, _ = self.get_dummy_components(scheduler_cls) @@ -151,11 +158,7 @@ class PeftLoraLoaderMixinTests: # explicitly. _, _, inputs = self.get_dummy_inputs(with_generator=False) output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] - self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora - - # Ensures that there's no inconsistency when reusing the cache. - yield - self.cached_non_lora_outputs.clear() + self.cached_base_pipe_outputs[scheduler_cls.__name__] = output_no_lora def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: @@ -348,7 +351,7 @@ class PeftLoraLoaderMixinTests: Tests a simple inference and makes sure it works as expected """ for scheduler_cls in self.scheduler_classes: - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) def test_simple_inference_with_text_lora(self): @@ -363,7 +366,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -446,7 +449,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -502,7 +505,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -540,7 +543,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -572,7 +575,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -607,7 +610,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -658,7 +661,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -709,7 +712,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -752,7 +755,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -793,7 +796,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -837,7 +840,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -875,7 +878,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -954,7 +957,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1083,7 +1086,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") @@ -1140,7 +1143,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1303,7 +1306,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1397,7 +1400,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1641,7 +1644,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1722,7 +1725,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1777,7 +1780,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_dora_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_dora_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -1909,7 +1912,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - original_out = self.cached_non_lora_outputs[scheduler_cls.__name__] + original_out = self.cached_base_pipe_outputs[scheduler_cls.__name__] no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)} logger = logging.get_logger("diffusers.loaders.peft") @@ -1955,7 +1958,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -2309,7 +2312,7 @@ class PeftLoraLoaderMixinTests: pipe = self.pipeline_class(**components).to(torch_device) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline( @@ -2359,7 +2362,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__] + output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__] if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config)