mirror of https://github.com/huggingface/diffusers.git

revert lora utils changes
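This revert removes two short-lived generalizations from the LoRA test mixin: a class-level `output_identifier_attribute` that let each pipeline test declare which field of the pipeline output to compare (read back via `getattr`), and a class-level `text_encoder_target_modules` list. The tests return to reading `pipe(...).images` and hard-coding the text-encoder target modules. As a minimal sketch of the two access patterns (`pipe` and `inputs` stand in for the fixtures the mixin builds; the "frames" remark is an assumption about the motivation, suggested by the CogVideoXPipeline skips also being reverted below):

    # Generalized pattern being reverted: the mixin declares which output
    # field to read, so a video pipeline could point it at e.g. "frames".
    output_identifier_attribute = "images"  # class attribute (removed below)
    output = getattr(pipe(**inputs), output_identifier_attribute)

    # Restored pattern: the output field is hard-coded.
    output = pipe(**inputs).images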
@@ -85,13 +85,8 @@ class PeftLoraLoaderMixinTests:
     unet_kwargs = None
     transformer_cls = None
     transformer_kwargs = None
     vae_cls = AutoencoderKL
     vae_kwargs = None
 
-    text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
-
-    output_identifier_attribute = "images"
-
     def get_dummy_components(self, scheduler_cls=None, use_dora=False):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
@@ -110,7 +105,7 @@ class PeftLoraLoaderMixinTests:
         scheduler = scheduler_cls(**self.scheduler_kwargs)
 
         torch.manual_seed(0)
-        vae = self.vae_cls(**self.vae_kwargs)
+        vae = AutoencoderKL(**self.vae_kwargs)
 
         text_encoder = self.text_encoder_cls.from_pretrained(self.text_encoder_id)
         tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id)
@@ -126,7 +121,7 @@ class PeftLoraLoaderMixinTests:
         text_lora_config = LoraConfig(
             r=rank,
             lora_alpha=rank,
-            target_modules=self.text_encoder_target_modules,
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
             init_lora_weights=False,
             use_dora=use_dora,
         )
@@ -217,7 +212,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
 
         _, _, inputs = self.get_dummy_inputs()
-        output_no_lora = getattr(pipe(**inputs), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
     def test_simple_inference_with_text_lora(self):
@@ -235,7 +230,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -249,7 +244,7 @@ class PeftLoraLoaderMixinTests:
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
 
-        output_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )
@@ -269,7 +264,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -283,36 +278,32 @@ class PeftLoraLoaderMixinTests:
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
 
-        output_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )
 
         if self.unet_kwargs is not None:
-            output_lora_scale = getattr(
-                pipe(**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}),
-                self.output_identifier_attribute,
-            )
+            output_lora_scale = pipe(
+                **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+            ).images
         else:
-            output_lora_scale = getattr(
-                pipe(**inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5}),
-                self.output_identifier_attribute,
-            )
+            output_lora_scale = pipe(
+                **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5}
+            ).images
         self.assertTrue(
             not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
             "Lora + scale should change the output",
         )
 
         if self.unet_kwargs is not None:
-            output_lora_0_scale = getattr(
-                pipe(**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}),
-                self.output_identifier_attribute,
-            )
+            output_lora_0_scale = pipe(
+                **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+            ).images
         else:
-            output_lora_0_scale = getattr(
-                pipe(**inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}),
-                self.output_identifier_attribute,
-            )
+            output_lora_0_scale = pipe(
+                **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}
+            ).images
         self.assertTrue(
             np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
             "Lora + 0 scale should lead to same result as no LoRA",
@@ -333,7 +324,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -356,7 +347,7 @@ class PeftLoraLoaderMixinTests:
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
 
-        ouput_fused = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertFalse(
             np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
         )
@@ -376,7 +367,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -403,7 +394,7 @@ class PeftLoraLoaderMixinTests:
                 "Lora not correctly unloaded in text encoder 2",
             )
 
-        ouput_unloaded = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(
             np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
             "Fused lora should change the output",
@@ -423,7 +414,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -436,7 +427,7 @@ class PeftLoraLoaderMixinTests:
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
 
-        images_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
@@ -470,9 +461,7 @@ class PeftLoraLoaderMixinTests:
 
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
 
-            images_lora_from_pretrained = getattr(
-                pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-            )
+            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
 
             if self.has_two_text_encoders or self.has_three_text_encoders:
@@ -511,7 +500,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -538,7 +527,7 @@ class PeftLoraLoaderMixinTests:
             }
         )
 
-        output_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )
@@ -547,9 +536,7 @@ class PeftLoraLoaderMixinTests:
         pipe.unload_lora_weights()
         pipe.load_lora_weights(state_dict)
 
-        output_partial_lora = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(
             not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
             "Removing adapters should change the output",
@@ -569,7 +556,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -582,7 +569,7 @@ class PeftLoraLoaderMixinTests:
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
 
-        images_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
@@ -602,9 +589,7 @@ class PeftLoraLoaderMixinTests:
                 "Lora not correctly set in text encoder 2",
             )
 
-        images_lora_save_pretrained = getattr(
-            pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertTrue(
             np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
@@ -628,7 +613,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -648,7 +633,7 @@ class PeftLoraLoaderMixinTests:
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
 
-        images_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
@@ -681,9 +666,7 @@ class PeftLoraLoaderMixinTests:
 
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
 
-            images_lora_from_pretrained = getattr(
-                pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-            )
+            images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -714,7 +697,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -733,36 +716,32 @@ class PeftLoraLoaderMixinTests:
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
 
-        output_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )
 
         if self.unet_kwargs is not None:
-            output_lora_scale = getattr(
-                pipe(**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}),
-                self.output_identifier_attribute,
-            )
+            output_lora_scale = pipe(
+                **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
+            ).images
         else:
-            output_lora_scale = getattr(
-                pipe(**inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5}),
-                self.output_identifier_attribute,
-            )
+            output_lora_scale = pipe(
+                **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5}
+            ).images
         self.assertTrue(
             not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
             "Lora + scale should change the output",
         )
 
         if self.unet_kwargs is not None:
-            output_lora_0_scale = getattr(
-                pipe(**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}),
-                self.output_identifier_attribute,
-            )
+            output_lora_0_scale = pipe(
+                **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
+            ).images
         else:
-            output_lora_0_scale = getattr(
-                pipe(**inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}),
-                self.output_identifier_attribute,
-            )
+            output_lora_0_scale = pipe(
+                **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}
+            ).images
         self.assertTrue(
             np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
             "Lora + 0 scale should lead to same result as no LoRA",
@@ -788,7 +767,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -820,7 +799,7 @@ class PeftLoraLoaderMixinTests:
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
 
-        ouput_fused = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        ouput_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertFalse(
             np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
         )
@@ -840,7 +819,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -876,7 +855,7 @@ class PeftLoraLoaderMixinTests:
                 "Lora not correctly unloaded in text encoder 2",
             )
 
-        ouput_unloaded = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(
             np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
             "Fused lora should change the output",
@@ -916,15 +895,11 @@ class PeftLoraLoaderMixinTests:
 
         pipe.fuse_lora()
 
-        output_fused_lora = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.unfuse_lora()
 
-        output_unfused_lora = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         # unloading should remove the LoRA layers
         self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
         denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
@@ -957,7 +932,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
@@ -985,20 +960,14 @@ class PeftLoraLoaderMixinTests:
 
         pipe.set_adapters("adapter-1")
 
-        output_adapter_1 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.set_adapters("adapter-2")
-        output_adapter_2 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.set_adapters(["adapter-1", "adapter-2"])
 
-        output_adapter_mixed = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         # Fuse and unfuse should lead to the same results
         self.assertFalse(
@@ -1018,7 +987,7 @@ class PeftLoraLoaderMixinTests:
 
         pipe.disable_lora()
 
-        output_disabled = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertTrue(
             np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
@@ -1043,7 +1012,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
         if self.unet_kwargs is not None:
@@ -1064,15 +1033,11 @@ class PeftLoraLoaderMixinTests:
 
         weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
         pipe.set_adapters("adapter-1", weights_1)
-        output_weights_1 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         weights_2 = {"unet": {"up": 5}}
         pipe.set_adapters("adapter-1", weights_2)
-        output_weights_2 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertFalse(
             np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
@@ -1088,7 +1053,7 @@ class PeftLoraLoaderMixinTests:
         )
 
         pipe.disable_lora()
-        output_disabled = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertTrue(
             np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
@@ -1113,7 +1078,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
@@ -1143,20 +1108,14 @@ class PeftLoraLoaderMixinTests:
         scales_2 = {"unet": {"down": 5, "mid": 5}}
         pipe.set_adapters("adapter-1", scales_1)
 
-        output_adapter_1 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.set_adapters("adapter-2", scales_2)
-        output_adapter_2 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
 
-        output_adapter_mixed = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         # Fuse and unfuse should lead to the same results
         self.assertFalse(
@@ -1176,7 +1135,7 @@ class PeftLoraLoaderMixinTests:
 
         pipe.disable_lora()
 
-        output_disabled = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertTrue(
             np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
@@ -1189,7 +1148,7 @@ class PeftLoraLoaderMixinTests:
 
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
-        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline", "CogVideoXPipeline"]:
+        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]:
             return
 
         def updown_options(blocks_with_tf, layers_per_block, value):
@@ -1294,7 +1253,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
@@ -1323,20 +1282,14 @@ class PeftLoraLoaderMixinTests:
 
         pipe.set_adapters("adapter-1")
 
-        output_adapter_1 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.set_adapters("adapter-2")
-        output_adapter_2 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.set_adapters(["adapter-1", "adapter-2"])
 
-        output_adapter_mixed = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertFalse(
             np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
@@ -1354,9 +1307,7 @@ class PeftLoraLoaderMixinTests:
         )
 
         pipe.delete_adapters("adapter-1")
-        output_deleted_adapter_1 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertTrue(
             np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
@@ -1364,9 +1315,7 @@ class PeftLoraLoaderMixinTests:
         )
 
         pipe.delete_adapters("adapter-2")
-        output_deleted_adapters = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertTrue(
             np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
@@ -1388,9 +1337,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_adapters(["adapter-1", "adapter-2"])
         pipe.delete_adapters(["adapter-1", "adapter-2"])
 
-        output_deleted_adapters = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertTrue(
             np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
@@ -1412,7 +1359,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
@@ -1441,20 +1388,14 @@ class PeftLoraLoaderMixinTests:
 
         pipe.set_adapters("adapter-1")
 
-        output_adapter_1 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.set_adapters("adapter-2")
-        output_adapter_2 = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.set_adapters(["adapter-1", "adapter-2"])
 
-        output_adapter_mixed = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         # Fuse and unfuse should lead to the same results
         self.assertFalse(
@@ -1473,9 +1414,7 @@ class PeftLoraLoaderMixinTests:
         )
 
         pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
-        output_adapter_mixed_weighted = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertFalse(
             np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
@@ -1484,7 +1423,7 @@ class PeftLoraLoaderMixinTests:
 
         pipe.disable_lora()
 
-        output_disabled = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertTrue(
             np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
@@ -1521,11 +1460,7 @@ class PeftLoraLoaderMixinTests:
                 "adapter-1"
             ].weight += float("inf")
         else:
-            for possible_attn in ["attn", "attn1"]:
-                attn = getattr(pipe.transformer.transformer_blocks[0], possible_attn, None)
-                if attn is not None:
-                    attn.to_q.lora_A["adapter-1"].weight += float("inf")
-                    break
+            pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
 
         # with `safe_fusing=True` we should see an Error
         with self.assertRaises(ValueError):
@@ -1534,7 +1469,7 @@ class PeftLoraLoaderMixinTests:
         # without we should not see an error, but every image will be black
         pipe.fuse_lora(safe_fusing=False)
 
-        out = getattr(pipe("test", num_inference_steps=2, output_type="np"), self.output_identifier_attribute)
+        out = pipe("test", num_inference_steps=2, output_type="np").images
 
         self.assertTrue(np.isnan(out).all())
 
@@ -1655,7 +1590,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1686,17 +1621,15 @@ class PeftLoraLoaderMixinTests:
 
         # set them to multi-adapter inference mode
         pipe.set_adapters(["adapter-1", "adapter-2"])
-        ouputs_all_lora = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        ouputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.set_adapters(["adapter-1"])
-        ouputs_lora_1 = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        ouputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         pipe.fuse_lora(adapter_names=["adapter-1"])
 
         # Fusing should still keep the LoRA layers so outpout should remain the same
-        outputs_lora_1_fused = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertTrue(
             np.allclose(ouputs_lora_1, outputs_lora_1_fused, atol=1e-3, rtol=1e-3),
@@ -1707,9 +1640,7 @@ class PeftLoraLoaderMixinTests:
         pipe.fuse_lora(adapter_names=["adapter-2", "adapter-1"])
 
         # Fusing should still keep the LoRA layers
-        output_all_lora_fused = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(
             np.allclose(output_all_lora_fused, ouputs_all_lora, atol=1e-3, rtol=1e-3),
             "Fused lora should not change the output",
@@ -1729,9 +1660,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_dora_lora = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
         self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
         pipe.text_encoder.add_adapter(text_lora_config)
@@ -1752,9 +1681,7 @@ class PeftLoraLoaderMixinTests:
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
 
-        output_dora_lora = getattr(
-            pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute
-        )
+        output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
 
         self.assertFalse(
             np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
@@ -1800,10 +1727,10 @@ class PeftLoraLoaderMixinTests:
             pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
 
         # Just makes sure it works..
-        _ = getattr(pipe(**inputs, generator=torch.manual_seed(0)), self.output_identifier_attribute)
+        _ = pipe(**inputs, generator=torch.manual_seed(0)).images
 
     def test_modify_padding_mode(self):
-        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline", "CogVideoXPipeline"]:
+        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "FluxPipeline"]:
             return
 
         def set_pad_mode(network, mode="circular"):
@@ -1824,4 +1751,4 @@ class PeftLoraLoaderMixinTests:
         set_pad_mode(pipe.unet, _pad_mode)
 
         _, _, inputs = self.get_dummy_inputs()
-        _ = getattr(pipe(**inputs), self.output_identifier_attribute)
+        _ = pipe(**inputs).images
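A pattern that recurs throughout these hunks is the branch on `self.unet_kwargs` to decide which keyword argument carries the LoRA scale: UNet-based pipelines take `cross_attention_kwargs`, while transformer-based pipelines (e.g. SD3) take `joint_attention_kwargs`. A minimal sketch of that dispatch, with a hypothetical helper name (`run_with_lora_scale` is not part of the diff):

    import torch

    def run_with_lora_scale(pipe, inputs, unet_kwargs, scale):
        # UNet pipelines route the LoRA scale through cross_attention_kwargs;
        # transformer pipelines use joint_attention_kwargs instead.
        key = "cross_attention_kwargs" if unet_kwargs is not None else "joint_attention_kwargs"
        return pipe(**inputs, generator=torch.manual_seed(0), **{key: {"scale": scale}}).images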