From a8c5801e26a4970ca587a29443cf04652fef30e2 Mon Sep 17 00:00:00 2001 From: DN6 Date: Wed, 10 Sep 2025 17:32:44 +0530 Subject: [PATCH] update --- tests/lora/utils.py | 58 ++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index bae5886905..86401d0591 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -158,6 +158,14 @@ class PeftLoraLoaderMixinTests: cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora}) setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs) + + def get_base_pipeline_output(self, scheduler_cls): + """ + Returns the cached base pipeline output for the given scheduler. + Properly handles accessing the class-level cache. + """ + cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {}) + return cached_base_pipe_outs[scheduler_cls.__name__] def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None): if self.unet_kwargs and self.transformer_kwargs: @@ -350,7 +358,7 @@ class PeftLoraLoaderMixinTests: Tests a simple inference and makes sure it works as expected """ for scheduler_cls in self.scheduler_classes: - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) def test_simple_inference_with_text_lora(self): @@ -365,7 +373,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -448,7 +456,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -504,7 +512,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -542,7 +550,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -574,7 +582,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -609,7 +617,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -660,7 +668,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -711,7 +719,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) @@ -754,7 +762,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -795,7 +803,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -839,7 +847,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -877,7 +885,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -956,7 +964,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1085,7 +1093,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") @@ -1142,7 +1150,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1305,7 +1313,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1399,7 +1407,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") @@ -1643,7 +1651,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1724,7 +1732,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) if "text_encoder" in self.pipeline_class._lora_loadable_modules: @@ -1779,7 +1787,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_dora_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_dora_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_dora_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -1911,7 +1919,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - original_out = self.cached_base_pipe_outs[scheduler_cls.__name__] + original_out = self.get_base_pipeline_output(scheduler_cls) no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)} logger = logging.get_logger("diffusers.loaders.peft") @@ -1957,7 +1965,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) @@ -2311,7 +2319,7 @@ class PeftLoraLoaderMixinTests: pipe = self.pipeline_class(**components).to(torch_device) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape) pipe, _ = self.add_adapters_to_pipeline( @@ -2361,7 +2369,7 @@ class PeftLoraLoaderMixinTests: pipe.set_progress_bar_config(disable=None) _, _, inputs = self.get_dummy_inputs(with_generator=False) - output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__] + output_no_lora = self.get_base_pipeline_output(scheduler_cls) if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe.text_encoder.add_adapter(text_lora_config)