Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
@@ -129,24 +129,17 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

-    cached_base_pipe_outputs = {}
-
-    def setUp(self):
-        super().setUp()
-        self.get_base_pipeline_output()
-
-    @classmethod
-    def tearDownClass(cls):
-        super().tearDownClass()
-        cls.cached_base_pipe_outputs.clear()
-
-    def get_base_pipeline_output(self):
-        scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
-        if self.cached_base_pipe_outputs and all(k in self.cached_base_pipe_outputs for k in scheduler_names):
-            return
+    cached_non_lora_outputs = {}
+
+    @pytest.fixture(scope="class", autouse=True)
+    def cache_non_lora_outputs(self):
+        """
+        This fixture will be executed once per test class and will populate
+        the cached_non_lora_outputs dictionary.
+        """
         for scheduler_cls in self.scheduler_classes:
-            if scheduler_cls.__name__ in self.cached_base_pipe_outputs:
+            # Check if the output for this scheduler is already cached to avoid re-running
+            if scheduler_cls.__name__ in self.cached_non_lora_outputs:
                 continue

             components, _, _ = self.get_dummy_components(scheduler_cls)
@@ -158,7 +151,11 @@ class PeftLoraLoaderMixinTests:
             # explicitly.
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.cached_base_pipe_outputs[scheduler_cls.__name__] = output_no_lora
+            self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
+
+        # Ensures that there's no inconsistency when reusing the cache.
+        yield
+        self.cached_non_lora_outputs.clear()

     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
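Note: the change above replaces unittest-style setUp/tearDownClass caching with a class-scoped, autouse pytest fixture. As a minimal sketch of that pattern (class and helper names below are illustrative, not taken from the diff): the fixture body before the yield runs once before the first test in the class, the code after the yield runs once after the last test, and a dict stored on the class is shared by every test in between.

    import pytest

    def compute_baseline(name):
        # Stand-in for the expensive no-LoRA pipeline call being cached (hypothetical helper).
        return f"baseline-for-{name}"

    class TestCachedBaseline:
        scheduler_names = ["DDIMScheduler", "EulerDiscreteScheduler"]
        cached_outputs = {}  # class attribute, shared by all tests in this class

        @pytest.fixture(scope="class", autouse=True)
        def cache_outputs(self):
            # Populate the cache once, before any test in the class runs.
            for name in self.scheduler_names:
                if name in self.cached_outputs:
                    continue  # already cached, skip the expensive call
                self.cached_outputs[name] = compute_baseline(name)
            yield  # all tests of the class execute while the fixture is suspended here
            # Teardown: clear the cache so state cannot leak into other test classes.
            self.cached_outputs.clear()

        def test_baseline_is_available(self):
            assert self.cached_outputs["DDIMScheduler"] == "baseline-for-DDIMScheduler"

The continue guard mirrors the diff's check that skips schedulers whose baseline output is already cached.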
@@ -351,7 +348,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

     def test_simple_inference_with_text_lora(self):
@@ -366,7 +363,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
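Note: add_adapters_to_pipeline is a helper defined elsewhere in the test mixin and is not part of this diff. Only as a hedged sketch of what the repeated calls do (the exact signature and return values are assumptions): it attaches the given LoRA configs to the text encoder and/or the denoiser (UNet or transformer, whichever the pipeline uses), verifies that adapter layers were actually injected, and returns the pipeline together with the denoiser.

    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None):
        # Sketch only: attach LoRA adapters where a config was provided and sanity-check the injection.
        if text_lora_config is not None and "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config)
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        if denoiser_lora_config is not None:
            denoiser.add_adapter(denoiser_lora_config)
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

        return pipe, denoiser

check_if_lora_correctly_set is sketched further down, next to the hunk that calls it directly.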
@@ -449,7 +446,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -505,7 +502,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -543,7 +540,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -575,7 +572,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -610,7 +607,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -661,7 +658,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -712,7 +709,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -755,7 +752,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -796,7 +793,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -840,7 +837,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -878,7 +875,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -957,7 +954,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1086,7 +1083,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
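Note: check_if_lora_correctly_set, used in the hunk above, is defined outside this diff. A minimal sketch of the kind of check it performs, assuming the PEFT backend (where injected LoRA layers subclass peft's BaseTunerLayer):

    from peft.tuners.tuners_utils import BaseTunerLayer

    def check_if_lora_correctly_set(model) -> bool:
        # True as soon as any submodule is a PEFT tuner layer,
        # i.e. LoRA weights were actually injected into the model.
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                return True
        return False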
@@ -1143,7 +1140,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1306,7 +1303,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1400,7 +1397,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1644,7 +1641,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1725,7 +1722,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1780,7 +1777,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_dora_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1912,7 +1909,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)

             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            original_out = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]

             no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
             logger = logging.get_logger("diffusers.loaders.peft")
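Note: in the hunk above, original_out is the cached no-LoRA baseline and no_op_state_dict holds keys that no LoRA loader can map onto the model. A hedged sketch of how such a test typically continues (CaptureLogger does exist in diffusers.utils.testing_utils, but the assertions and log wording here are assumptions, and pipe/inputs/logger refer to names already set up in the hunk):

    import numpy as np
    from diffusers.utils.testing_utils import CaptureLogger

    logger.setLevel(30)  # 30 == logging.WARNING
    with CaptureLogger(logger) as cap_logger:
        # Loading a state dict with no recognizable LoRA keys should be a no-op.
        pipe.load_lora_weights(no_op_state_dict)

    out_after = pipe(**inputs, generator=torch.manual_seed(0))[0]
    # The output must match the cached baseline, and the loader should have logged why nothing was applied.
    assert np.allclose(original_out, out_after, atol=1e-4)
    assert "lora" in cap_logger.out.lower()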
@@ -1958,7 +1955,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2312,7 +2309,7 @@ class PeftLoraLoaderMixinTests:
             pipe = self.pipeline_class(**components).to(torch_device)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(
@@ -2362,7 +2359,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = self.cached_base_pipe_outputs[scheduler_cls.__name__]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)