mirror of https://github.com/huggingface/diffusers.git

fight more tests

Aryan
2024-09-17 22:27:10 +02:00
parent 0e1c569c58
commit 7c843949f6
2 changed files with 58 additions and 18 deletions


@@ -91,6 +91,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
     text_encoder_target_modules = ["q", "k", "v", "o"]
+    test_text_encoder_lora = False
 
     @property
     def output_shape(self):
@@ -155,18 +156,29 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         self.assertTrue(np.isnan(out).all())
 
-    @unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
-        pass
+        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=5e-3)
 
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    def test_simple_inference_with_partial_text_lora(self):
+        pass
+
-    @unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora(self):
         pass
 
-    @unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass
 
-    @unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
 
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    def test_simple_inference_with_text_lora_save_load(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    def test_simple_inference_with_text_denoiser_lora_unfused(self):
+        pass
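
Taken together, these hunks let a pipeline's test class opt out of text-encoder LoRA coverage: the class flag tells the shared mixin to skip its text-encoder checks, and the unsupported tests are additionally skipped outright. A minimal sketch of that opt-out pattern, with hypothetical names (`LoraMixin`, `test_lora`) standing in for the real mixin and tests:

import unittest


class LoraMixin:
    # Mixin default: pipelines are assumed to support text encoder LoRA.
    test_text_encoder_lora = True

    def test_lora(self):
        if not self.test_text_encoder_lora:
            # Guarded block becomes a no-op / skip for opted-out pipelines.
            self.skipTest("Text encoder LoRA is not supported.")
        # ... exercise the text encoder LoRA path here ...


class CogVideoXLikeTests(unittest.TestCase, LoraMixin):
    # Opting out, as the CogVideoX test class does in the diff above.
    test_text_encoder_lora = False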


@@ -89,6 +89,7 @@ class PeftLoraLoaderMixinTests:
     vae_kwargs = None
 
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+    test_text_encoder_lora = True
 
     def get_dummy_components(self, scheduler_cls=None, use_dora=False):
         if self.unet_kwargs and self.transformer_kwargs:
@@ -423,8 +424,11 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe.text_encoder.add_adapter(text_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+        if self.test_text_encoder_lora:
+            pipe.text_encoder.add_adapter(text_lora_config)
+            self.assertTrue(
+                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
 
         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
@@ -619,13 +623,17 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe.text_encoder.add_adapter(text_lora_config)
+        if self.test_text_encoder_lora:
+            pipe.text_encoder.add_adapter(text_lora_config)
+            self.assertTrue(
+                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
 
         if self.unet_kwargs is not None:
             pipe.unet.add_adapter(denoiser_lora_config)
         else:
             pipe.transformer.add_adapter(denoiser_lora_config)
 
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
         denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
         self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet")
@@ -639,7 +647,9 @@ class PeftLoraLoaderMixinTests:
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         with tempfile.TemporaryDirectory() as tmpdirname:
-            text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+            text_encoder_state_dict = (
+                get_peft_model_state_dict(pipe.text_encoder) if self.test_text_encoder_lora else None
+            )
 
             if self.unet_kwargs is not None:
                 denoiser_state_dict = get_peft_model_state_dict(pipe.unet)
@@ -670,7 +680,12 @@ class PeftLoraLoaderMixinTests:
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
             images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+            if self.test_text_encoder_lora:
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -882,13 +897,17 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        pipe.text_encoder.add_adapter(text_lora_config)
+        if self.test_text_encoder_lora:
+            pipe.text_encoder.add_adapter(text_lora_config)
+            self.assertTrue(
+                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
 
         if self.unet_kwargs is not None:
             pipe.unet.add_adapter(denoiser_lora_config)
         else:
             pipe.transformer.add_adapter(denoiser_lora_config)
 
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
         denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
         self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -906,8 +925,11 @@ class PeftLoraLoaderMixinTests:
         pipe.unfuse_lora()
         output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         # unloading should remove the LoRA layers
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+        if self.test_text_encoder_lora:
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+
         denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
         self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")
@@ -1581,7 +1603,9 @@ class PeftLoraLoaderMixinTests:
             self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
 
     @require_peft_version_greater(peft_version="0.6.2")
-    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+    def test_simple_inference_with_text_lora_denoiser_fused_multi(
+        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
+    ):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected - with unet and multi-adapter case
@@ -1599,7 +1623,12 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+        if self.test_text_encoder_lora:
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            self.assertTrue(
+                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
 
         if self.unet_kwargs is not None:
             pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
         else:
@@ -1612,7 +1641,6 @@ class PeftLoraLoaderMixinTests:
         else:
             pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
 
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
         denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
         self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -1638,7 +1666,7 @@ class PeftLoraLoaderMixinTests:
         outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
         self.assertTrue(
-            np.allclose(ouputs_lora_1, outputs_lora_1_fused, atol=1e-3, rtol=1e-3),
+            np.allclose(ouputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
             "Fused lora should not change the output",
         )
@@ -1648,7 +1676,7 @@ class PeftLoraLoaderMixinTests:
         # Fusing should still keep the LoRA layers
         output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
-            np.allclose(output_all_lora_fused, ouputs_all_lora, atol=1e-3, rtol=1e-3),
+            np.allclose(output_all_lora_fused, ouputs_all_lora, atol=expected_atol, rtol=expected_rtol),
             "Fused lora should not change the output",
        )
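
The last two hunks thread `expected_atol`/`expected_rtol` through the fused-multi test so a subclass can relax tolerances without duplicating the test body. A hedged sketch of that override pattern (class and method names are illustrative, and the placeholder arrays stand in for real pipeline outputs):

import numpy as np


class BaseLoraTests:
    def test_fused_multi(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
        # Compare fused vs. unfused outputs under caller-supplied tolerances.
        fused, unfused = np.zeros(4), np.zeros(4)  # placeholders for pipeline outputs
        assert np.allclose(fused, unfused, atol=expected_atol, rtol=expected_rtol), \
            "Fused lora should not change the output"


class CogVideoXLikeLoraTests(BaseLoraTests):
    def test_fused_multi(self):
        # Video pipelines accumulate more numerical drift, so forward a looser
        # atol, mirroring the CogVideoX override in the first file of this commit.
        super().test_fused_multi(expected_atol=5e-3)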