mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fight more tests
This commit is contained in:
@@ -91,6 +91,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
|
||||
|
||||
text_encoder_target_modules = ["q", "k", "v", "o"]
|
||||
test_text_encoder_lora = False
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
@@ -155,18 +156,29 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
@unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=5e-3)
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
pass
|
||||
|
||||
@@ -89,6 +89,7 @@ class PeftLoraLoaderMixinTests:
|
||||
vae_kwargs = None
|
||||
|
||||
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
test_text_encoder_lora = True
|
||||
|
||||
def get_dummy_components(self, scheduler_cls=None, use_dora=False):
|
||||
if self.unet_kwargs and self.transformer_kwargs:
|
||||
@@ -423,8 +424,11 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
if self.test_text_encoder_lora:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
if self.has_two_text_encoders or self.has_three_text_encoders:
|
||||
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
|
||||
@@ -619,13 +623,17 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
if self.test_text_encoder_lora:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config)
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config)
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet")
|
||||
|
||||
@@ -639,7 +647,9 @@ class PeftLoraLoaderMixinTests:
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
|
||||
text_encoder_state_dict = (
|
||||
get_peft_model_state_dict(pipe.text_encoder) if self.test_text_encoder_lora else None
|
||||
)
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.unet)
|
||||
@@ -670,7 +680,12 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
|
||||
|
||||
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
|
||||
if self.test_text_encoder_lora:
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
|
||||
|
||||
@@ -882,13 +897,17 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
if self.test_text_encoder_lora:
|
||||
pipe.text_encoder.add_adapter(text_lora_config)
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config)
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config)
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
|
||||
|
||||
@@ -906,8 +925,11 @@ class PeftLoraLoaderMixinTests:
|
||||
pipe.unfuse_lora()
|
||||
|
||||
output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
# unloading should remove the LoRA layers
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
|
||||
if self.test_text_encoder_lora:
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
|
||||
|
||||
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")
|
||||
|
||||
@@ -1581,7 +1603,9 @@ class PeftLoraLoaderMixinTests:
|
||||
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
|
||||
|
||||
@require_peft_version_greater(peft_version="0.6.2")
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(
|
||||
self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
|
||||
):
|
||||
"""
|
||||
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
|
||||
and makes sure it works as expected - with unet and multi-adapter case
|
||||
@@ -1599,7 +1623,12 @@ class PeftLoraLoaderMixinTests:
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
if self.test_text_encoder_lora:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
else:
|
||||
@@ -1612,7 +1641,6 @@ class PeftLoraLoaderMixinTests:
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
|
||||
|
||||
@@ -1638,7 +1666,7 @@ class PeftLoraLoaderMixinTests:
|
||||
outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(ouputs_lora_1, outputs_lora_1_fused, atol=1e-3, rtol=1e-3),
|
||||
np.allclose(ouputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
|
||||
"Fused lora should not change the output",
|
||||
)
|
||||
|
||||
@@ -1648,7 +1676,7 @@ class PeftLoraLoaderMixinTests:
|
||||
# Fusing should still keep the LoRA layers
|
||||
output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(
|
||||
np.allclose(output_all_lora_fused, ouputs_all_lora, atol=1e-3, rtol=1e-3),
|
||||
np.allclose(output_all_lora_fused, ouputs_all_lora, atol=expected_atol, rtol=expected_rtol),
|
||||
"Fused lora should not change the output",
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user