diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 72c1dddaa2..7eaee021f5 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -129,6 +129,30 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
+    cached_non_lora_outputs = {}
+
+    @pytest.fixture(scope="class", autouse=True)
+    def cache_non_lora_outputs(self, request):
+        """
+        This fixture will be executed once per test class and will populate
+        the cached_non_lora_outputs dictionary.
+        """
+        for scheduler_cls in self.scheduler_classes:
+            # Check if the output for this scheduler is already cached to avoid re-running
+            if scheduler_cls.__name__ in self.cached_non_lora_outputs:
+                continue
+
+            components, _, _ = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+
+            # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
+            # explicitly.
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
+
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
@@ -320,13 +344,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-
-            _, _, inputs = self.get_dummy_inputs()
-            output_no_lora = pipe(**inputs)[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
     def test_simple_inference_with_text_lora(self):
@@ -341,7 +359,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -424,7 +442,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -480,7 +498,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -518,7 +536,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -550,7 +568,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -585,7 +603,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -636,7 +654,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -687,7 +705,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -730,7 +748,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -771,7 +789,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -815,7 +833,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -853,7 +871,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -932,7 +950,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1061,7 +1079,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1118,7 +1136,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1281,7 +1299,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1375,7 +1393,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1619,7 +1637,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1700,7 +1718,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1755,7 +1773,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1887,7 +1905,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
             logger = logging.get_logger("diffusers.loaders.peft")
@@ -1933,7 +1951,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2287,7 +2305,7 @@ class PeftLoraLoaderMixinTests:
         pipe = self.pipeline_class(**components).to(torch_device)
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe, _ = self.add_adapters_to_pipeline(
@@ -2337,7 +2355,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)
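
Note on the caching pattern the diff relies on: a pytest fixture declared with scope="class" and autouse=True runs once per test class, so the expensive no-LoRA baseline is computed a single time per scheduler and stored in the class-level cached_non_lora_outputs dict that individual tests then read from. Below is a minimal, self-contained sketch of that pattern; FakeScheduler, FakePipeline, and TestCachedBaseline are hypothetical stand-ins for illustration only, not part of the diffusers test suite.

import pytest


class FakeScheduler:
    """Hypothetical stand-in for a diffusers scheduler class."""


class FakePipeline:
    """Hypothetical stand-in for the pipeline under test."""

    def __call__(self):
        # Pretend this is the expensive no-LoRA forward pass.
        return [0.123]


class TestCachedBaseline:
    # Class-level cache shared by all tests in the class, keyed by scheduler name.
    cached_non_lora_outputs = {}
    scheduler_classes = [FakeScheduler]

    @pytest.fixture(scope="class", autouse=True)
    def cache_non_lora_outputs(self):
        # scope="class" -> runs once for the whole class; autouse=True -> no test
        # has to request the fixture explicitly.
        for scheduler_cls in self.scheduler_classes:
            if scheduler_cls.__name__ in self.cached_non_lora_outputs:
                continue  # baseline already computed for this scheduler
            pipe = FakePipeline()
            self.cached_non_lora_outputs[scheduler_cls.__name__] = pipe()

    def test_baseline_available(self):
        # Tests read the cached baseline instead of re-running the pipeline.
        assert FakeScheduler.__name__ in self.cached_non_lora_outputs

Keying the cache on scheduler_cls.__name__ mirrors the diff above: each scheduler gets one baseline, and every test that only needs the unmodified output can skip a full pipeline run.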