mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
@@ -158,6 +158,14 @@ class PeftLoraLoaderMixinTests:
             cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora})
 
         setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
 
+    def get_base_pipeline_output(self, scheduler_cls):
+        """
+        Returns the cached base pipeline output for the given scheduler.
+        Properly handles accessing the class-level cache.
+        """
+        cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
+        return cached_base_pipe_outs[scheduler_cls.__name__]
+
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
@@ -350,7 +358,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
     def test_simple_inference_with_text_lora(self):
@@ -365,7 +373,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -448,7 +456,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -504,7 +512,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -542,7 +550,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -574,7 +582,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -609,7 +617,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -660,7 +668,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -711,7 +719,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -754,7 +762,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -795,7 +803,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -839,7 +847,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -877,7 +885,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -956,7 +964,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1085,7 +1093,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1142,7 +1150,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1305,7 +1313,7 @@ class PeftLoraLoaderMixinTests:
            pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1399,7 +1407,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1643,7 +1651,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1724,7 +1732,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1779,7 +1787,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_dora_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_dora_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1911,7 +1919,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            original_out = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            original_out = self.get_base_pipeline_output(scheduler_cls)
 
             no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
             logger = logging.get_logger("diffusers.loaders.peft")
@@ -1957,7 +1965,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2311,7 +2319,7 @@ class PeftLoraLoaderMixinTests:
             pipe = self.pipeline_class(**components).to(torch_device)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(
@@ -2361,7 +2369,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = self.cached_base_pipe_outs[scheduler_cls.__name__]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)
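The pattern behind this change, in one place: base pipeline outputs are computed once per scheduler class, stored on the test class itself via `setattr(type(self), ...)`, and every test reads them back through the new `get_base_pipeline_output` accessor instead of indexing `cached_base_pipe_outs` directly. Below is a minimal, self-contained sketch of that class-level caching idiom; the scheduler classes and the `_compute_base_output` helper are hypothetical stand-ins, not diffusers APIs — only the cache mechanics mirror the diff above.

```python
# Minimal sketch of the class-level caching pattern from this commit.
# Scheduler classes and _compute_base_output are hypothetical placeholders.
import unittest


class EulerScheduler:
    """Hypothetical stand-in for a diffusers scheduler class."""


class DDIMScheduler:
    """Hypothetical stand-in for a second scheduler class."""


class CachedOutputTests(unittest.TestCase):
    scheduler_classes = [EulerScheduler, DDIMScheduler]

    def _compute_base_output(self, scheduler_cls):
        # Stand-in for running the real pipeline once without LoRA adapters.
        return f"base-output-for-{scheduler_cls.__name__}"

    def setUp(self):
        # Populate the cache once and attach it to the *class*, so every
        # test method (each of which runs on a fresh instance) shares it.
        cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
        for scheduler_cls in self.scheduler_classes:
            if scheduler_cls.__name__ not in cached_base_pipe_outs:
                output_no_lora = self._compute_base_output(scheduler_cls)
                cached_base_pipe_outs.update({scheduler_cls.__name__: output_no_lora})
        setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)

    def get_base_pipeline_output(self, scheduler_cls):
        # Same accessor shape as the diff: read the class-level cache with a
        # safe default, then index by scheduler name.
        cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
        return cached_base_pipe_outs[scheduler_cls.__name__]

    def test_cache_is_shared(self):
        for scheduler_cls in self.scheduler_classes:
            out = self.get_base_pipeline_output(scheduler_cls)
            self.assertEqual(out, f"base-output-for-{scheduler_cls.__name__}")


if __name__ == "__main__":
    unittest.main()
```

Pinning the cache to the class matters because unittest instantiates a fresh test object for every test method: an instance attribute would be recomputed each time, while `setattr(type(self), ...)` makes the expensive base run happen once per scheduler, and the `getattr(..., {})` default keeps the accessor from raising AttributeError before anything has been cached.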