mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Tests] Cleanup lora tests utils (#11276)
* start cleaning up lora test utils for reusability * update * updates * updates
This commit is contained in:
@@ -260,6 +260,31 @@ class PeftLoraLoaderMixinTests:
|
||||
|
||||
return modules_to_save
|
||||
|
||||
def check_if_adapters_added_correctly(
|
||||
self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
|
||||
):
|
||||
if text_lora_config is not None:
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
if denoiser_lora_config is not None:
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
else:
|
||||
denoiser = None
|
||||
|
||||
if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
return pipe, denoiser
|
||||
|
||||
def test_simple_inference(self):
|
||||
"""
|
||||
Tests a simple inference and makes sure it works as expected
|
||||
@@ -289,16 +314,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
lora_loadable_components = self.pipeline_class._lora_loadable_modules
|
||||
if "text_encoder_2" in lora_loadable_components:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(
|
||||
@@ -381,22 +397,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -459,16 +460,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
lora_loadable_components = self.pipeline_class._lora_loadable_modules
|
||||
if "text_encoder_2" in lora_loadable_components:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(
|
||||
@@ -506,15 +498,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
pipe.fuse_lora()
|
||||
# Fusing should still keep the LoRA layers
|
||||
@@ -546,19 +530,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
lora_loadable_components = self.pipeline_class._lora_loadable_modules
|
||||
if "text_encoder_2" in lora_loadable_components:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
# unloading should remove the LoRA layers
|
||||
@@ -593,18 +565,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -655,22 +616,20 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
# Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
|
||||
# supports missing layers (PR#8324).
|
||||
state_dict = {
|
||||
f"text_encoder.{module_name}": param
|
||||
for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
|
||||
if "text_model.encoder.layers.4" not in module_name
|
||||
}
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
|
||||
state_dict = {}
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
# Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
|
||||
# supports missing layers (PR#8324).
|
||||
state_dict = {
|
||||
f"text_encoder.{module_name}": param
|
||||
for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
|
||||
if "text_model.encoder.layers.4" not in module_name
|
||||
}
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
state_dict.update(
|
||||
{
|
||||
f"text_encoder_2.{module_name}": param
|
||||
@@ -694,7 +653,7 @@ class PeftLoraLoaderMixinTests:
|
||||
"Removing adapters should change the output",
|
||||
)
|
||||
|
||||
def test_simple_inference_save_pretrained(self):
|
||||
def test_simple_inference_save_pretrained_with_text_lora(self):
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
|
||||
"""
|
||||
@@ -708,16 +667,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -726,10 +676,11 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
|
||||
pipe_from_pretrained.to(torch_device)
|
||||
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
|
||||
"Lora not correctly set in text encoder",
|
||||
)
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
|
||||
"Lora not correctly set in text encoder",
|
||||
)
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
@@ -759,22 +710,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -820,22 +756,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(
|
||||
@@ -879,22 +800,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
|
||||
|
||||
@@ -932,22 +838,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
# unloading should remove the LoRA layers
|
||||
@@ -983,22 +874,7 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
|
||||
output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
@@ -1104,6 +980,8 @@ class PeftLoraLoaderMixinTests:
|
||||
)
|
||||
|
||||
def test_wrong_adapter_name_raises_error(self):
|
||||
adapter_name = "adapter-1"
|
||||
|
||||
scheduler_cls = self.scheduler_classes[0]
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
@@ -1111,20 +989,9 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(
|
||||
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as err_context:
|
||||
pipe.set_adapters("test")
|
||||
@@ -1132,10 +999,11 @@ class PeftLoraLoaderMixinTests:
|
||||
self.assertTrue("not in the list of present adapters" in str(err_context.exception))
|
||||
|
||||
# test this works.
|
||||
pipe.set_adapters("adapter-1")
|
||||
pipe.set_adapters(adapter_name)
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
def test_multiple_wrong_adapter_name_raises_error(self):
|
||||
adapter_name = "adapter-1"
|
||||
scheduler_cls = self.scheduler_classes[0]
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
@@ -1143,33 +1011,22 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(
|
||||
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
|
||||
)
|
||||
|
||||
scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
|
||||
logger = logging.get_logger("diffusers.loaders.lora_base")
|
||||
logger.setLevel(30)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components)
|
||||
pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)
|
||||
|
||||
wrong_components = sorted(set(scale_with_wrong_components.keys()))
|
||||
msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
|
||||
self.assertTrue(msg in str(cap_logger.out))
|
||||
|
||||
# test this works.
|
||||
pipe.set_adapters("adapter-1")
|
||||
pipe.set_adapters(adapter_name)
|
||||
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
@@ -1804,20 +1661,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
lora_loadable_components = self.pipeline_class._lora_loadable_modules
|
||||
if "text_encoder_2" in lora_loadable_components:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -1908,18 +1752,7 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
|
||||
@@ -2011,22 +1844,7 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
lora_scale = 0.5
|
||||
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
|
||||
@@ -2211,22 +2029,7 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe = pipe.to(torch_device, dtype=compute_dtype)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
|
||||
|
||||
if storage_dtype is not None:
|
||||
denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
|
||||
Reference in New Issue
Block a user