diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py
index 108a79acab..b9fda84d4a 100644
--- a/tests/lora/test_lora_layers_cogvideox.py
+++ b/tests/lora/test_lora_layers_cogvideox.py
@@ -91,6 +91,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

     text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
     text_encoder_target_modules = ["q", "k", "v", "o"]
+    test_text_encoder_lora = False

     @property
     def output_shape(self):
@@ -155,18 +156,29 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

         self.assertTrue(np.isnan(out).all())

-    @unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
+    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=5e-3)
+
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_partial_text_lora(self):
         pass

-    @unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora(self):
         pass

-    @unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora_and_scale(self):
         pass

-    @unittest.skip("Text encoder LoRA training is not supported in CogVideoX.")
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora_fused(self):
         pass
+
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    def test_simple_inference_with_text_lora_save_load(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
+    def test_simple_inference_with_text_denoiser_lora_unfused(self):
+        pass
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 29da388915..48019329c2 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -89,6 +89,7 @@ class PeftLoraLoaderMixinTests:
     vae_kwargs = None

     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+    test_text_encoder_lora = True

     def get_dummy_components(self, scheduler_cls=None, use_dora=False):
         if self.unet_kwargs and self.transformer_kwargs:
@@ -423,8 +424,11 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe.text_encoder.add_adapter(text_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+        if self.test_text_encoder_lora:
+            pipe.text_encoder.add_adapter(text_lora_config)
+            self.assertTrue(
+                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
@@ -619,13 +623,17 @@ class PeftLoraLoaderMixinTests:

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe.text_encoder.add_adapter(text_lora_config)
+        if self.test_text_encoder_lora:
+            pipe.text_encoder.add_adapter(text_lora_config)
+            self.assertTrue(
+                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
+
         if self.unet_kwargs is not None:
             pipe.unet.add_adapter(denoiser_lora_config)
         else:
             pipe.transformer.add_adapter(denoiser_lora_config)

-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
         denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
         self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet")
@@ -639,7 +647,9 @@ class PeftLoraLoaderMixinTests:
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
-            text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
+            text_encoder_state_dict = (
+                get_peft_model_state_dict(pipe.text_encoder) if self.test_text_encoder_lora else None
+            )

             if self.unet_kwargs is not None:
                 denoiser_state_dict = get_peft_model_state_dict(pipe.unet)
@@ -670,7 +680,12 @@ class PeftLoraLoaderMixinTests:
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

             images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+            if self.test_text_encoder_lora:
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
             denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
             self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

@@ -882,13 +897,17 @@ class PeftLoraLoaderMixinTests:

         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe.text_encoder.add_adapter(text_lora_config)
+        if self.test_text_encoder_lora:
+            pipe.text_encoder.add_adapter(text_lora_config)
+            self.assertTrue(
+                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
+
         if self.unet_kwargs is not None:
             pipe.unet.add_adapter(denoiser_lora_config)
         else:
             pipe.transformer.add_adapter(denoiser_lora_config)

-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
         denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
         self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
@@ -906,8 +925,11 @@ class PeftLoraLoaderMixinTests:

         pipe.unfuse_lora()
         output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
         # unloading should remove the LoRA layers
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+        if self.test_text_encoder_lora:
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+
         denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
         self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")

@@ -1581,7 +1603,9 @@ class PeftLoraLoaderMixinTests:
             self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

     @require_peft_version_greater(peft_version="0.6.2")
-    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+    def test_simple_inference_with_text_lora_denoiser_fused_multi(
+        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
+    ):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected - with unet and multi-adapter case
@@ -1599,7 +1623,12 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+        if self.test_text_encoder_lora:
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            self.assertTrue(
+                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
+
         if self.unet_kwargs is not None:
             pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
         else:
@@ -1612,7 +1641,6 @@ class PeftLoraLoaderMixinTests:
         else:
             pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")

-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
         denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
         self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")

@@ -1638,7 +1666,7 @@ class PeftLoraLoaderMixinTests:

         outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

         self.assertTrue(
-            np.allclose(ouputs_lora_1, outputs_lora_1_fused, atol=1e-3, rtol=1e-3),
+            np.allclose(ouputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
             "Fused lora should not change the output",
         )
@@ -1648,7 +1676,7 @@ class PeftLoraLoaderMixinTests:

         # Fusing should still keep the LoRA layers
         output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
-            np.allclose(output_all_lora_fused, ouputs_all_lora, atol=1e-3, rtol=1e-3),
+            np.allclose(output_all_lora_fused, ouputs_all_lora, atol=expected_atol, rtol=expected_rtol),
             "Fused lora should not change the output",
         )
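Reviewer note: the diff hinges on two patterns — an opt-out class flag (`test_text_encoder_lora`) defined on `PeftLoraLoaderMixinTests` and flipped to `False` in `CogVideoXLoRATests`, and a base test that accepts `expected_atol`/`expected_rtol` so a subclass can widen tolerances via `super()`. Below is a minimal, self-contained sketch of both mechanisms; the `Dummy*` names are illustrative stand-ins, not classes from this diff.

```python
import unittest


class DummyLoraMixin:
    # Default mirrors PeftLoraLoaderMixinTests: text encoder LoRA is tested.
    test_text_encoder_lora = True

    def run_checks(self, expected_atol: float = 1e-3):
        checked = []
        if self.test_text_encoder_lora:
            # Gated the same way the add_adapter / check_if_lora_correctly_set
            # blocks are gated in tests/lora/utils.py.
            checked.append("text_encoder")
        checked.append("denoiser")
        return checked, expected_atol


class DummyCogVideoXTests(DummyLoraMixin, unittest.TestCase):
    # Opt out, as CogVideoXLoRATests does: its T5 text encoder gets no LoRA.
    test_text_encoder_lora = False

    def test_gating_and_tolerance(self):
        # Widening the tolerance mirrors the CogVideoX override, which
        # forwards expected_atol=5e-3 to the base fused-multi test.
        checked, atol = self.run_checks(expected_atol=5e-3)
        self.assertEqual(checked, ["denoiser"])
        self.assertEqual(atol, 5e-3)


if __name__ == "__main__":
    unittest.main()
```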