diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py
index 60962e0485..5db5e87f4d 100644
--- a/tests/lora/test_lora_layers_cogvideox.py
+++ b/tests/lora/test_lora_layers_cogvideox.py
@@ -15,19 +15,33 @@
 import sys
 import unittest
 
+import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
 
-from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
-from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend
+from diffusers import (
+    AutoencoderKLCogVideoX,
+    CogVideoXDDIMScheduler,
+    CogVideoXDPMScheduler,
+    CogVideoXPipeline,
+    CogVideoXTransformer3DModel,
+)
+from diffusers.utils.testing_utils import (
+    floats_tensor,
+    is_peft_available,
+    require_peft_backend,
+    skip_mps,
+    torch_device,
+)
 
 if is_peft_available():
-    pass
+    from peft import LoraConfig
+    from peft.utils import get_peft_model_state_dict
 
 sys.path.append(".")
 
-from utils import PeftLoraLoaderMixinTests  # noqa: E402
+from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402
 
 
 @require_peft_backend
@@ -79,8 +93,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 
     text_encoder_target_modules = ["q", "k", "v", "o"]
 
-    output_identifier_attribute = "frames"
-
     @property
     def output_shape(self):
         return (1, 9, 16, 16, 3)
@@ -90,7 +102,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         sequence_length = 16
         num_channels = 4
         num_frames = 9
-        num_latent_frames = 3  # (9 - 1) // temporal_compression_ratio + 1
+        num_latent_frames = 3  # (num_frames - 1) // temporal_compression_ratio + 1
         sizes = (2, 2)
 
         generator = torch.manual_seed(0)
@@ -113,38 +125,101 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 
         return noise, input_ids, pipeline_inputs
 
+    @skip_mps
     def test_lora_fuse_nan(self):
-        # TODO(aryan): Stop fighting me and just work!
-        pass
+        scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
+        for scheduler_cls in scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+            denoiser_to_check = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+            self.assertTrue(check_if_lora_correctly_set(denoiser_to_check), "Lora not correctly set in denoiser")
+
+            # corrupt one LoRA weight with `inf` values
+            with torch.no_grad():
+                pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
+
+            # with `safe_fusing=True` we should see an error
+            with self.assertRaises(ValueError):
+                pipe.fuse_lora(safe_fusing=True)
+
+            # without it, we should not see an error, but the output will be full of NaNs
+            pipe.fuse_lora(safe_fusing=False)
+
+            out = pipe(
+                "test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
+            )[0]
+
+            self.assertTrue(np.isnan(out).all())
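The new `test_lora_fuse_nan` relies on two behaviors of `fuse_lora`: with `safe_fusing=True`, non-finite adapter weights must abort the fusion with a `ValueError`, and with `safe_fusing=False` the corrupted weights are fused silently, so everything the pipeline then produces is NaN. A minimal sketch of that contract, with a hypothetical `fuse_if_safe` helper standing in for the real `fuse_lora` implementation:

```python
import torch


def fuse_if_safe(base_weight: torch.Tensor, lora_delta: torch.Tensor, safe_fusing: bool = True) -> torch.Tensor:
    # Guard against fusing NaN/inf adapter weights into the base model,
    # mirroring what the test expects from `fuse_lora(safe_fusing=True)`.
    if safe_fusing and not torch.isfinite(lora_delta).all():
        raise ValueError("LoRA weights contain non-finite values; refusing to fuse.")
    return base_weight + lora_delta


base = torch.zeros(4, 4)
delta = torch.full((4, 4), float("inf"))

try:
    fuse_if_safe(base, delta, safe_fusing=True)
except ValueError as err:
    print(err)  # raised, as asserted via assertRaises in the test

# Without the guard, fusion succeeds but any downstream arithmetic turns NaN:
fused = fuse_if_safe(base, delta, safe_fusing=False)
print(torch.isnan(fused * 0.0).all())  # tensor(True): inf * 0 -> nan
```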
 
     def test_simple_inference_with_partial_text_lora(self):
-        # TODO(aryan): Stop fighting me and just work!
-        pass
+        """
+        Tests a simple inference with lora attached on the text encoder
+        with different ranks and some adapters removed
+        and makes sure it works as expected
+        """
 
-    def test_simple_inference_with_text_denoiser_block_scale(self):
-        # TODO(aryan): Stop fighting me and just work!
-        pass
+        scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
+        for scheduler_cls in scheduler_classes:
+            components, _, _ = self.get_dummy_components(scheduler_cls)
+            rank_pattern = dict(zip(self.text_encoder_target_modules, [1, 2, 3]))
+            text_lora_config = LoraConfig(
+                r=4,
+                rank_pattern=rank_pattern,
+                lora_alpha=4,
+                target_modules=self.text_encoder_target_modules,
+                init_lora_weights=False,
+                use_dora=False,
+            )
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
-        # TODO(aryan): Stop fighting me and just work!
-        pass
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(output_no_lora.shape == self.output_shape)
 
-    def test_simple_inference_with_text_denoiser_lora_save_load(self):
-        # TODO(aryan): Stop fighting me and just work!
-        pass
+            pipe.text_encoder.add_adapter(text_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+            # Gather the state dict for the PEFT model, excluding `block.4`, to ensure `load_lora_into_text_encoder`
+            # supports missing layers (PR#8324).
+            state_dict = {
+                f"text_encoder.{module_name}": param
+                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
+                if "block.4.layer" not in module_name
+            }
+
+            output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(
+                not np.allclose(output_lora, output_no_lora, atol=1e-4, rtol=1e-4), "Lora should change the output"
+            )
+
+            # Unload lora and load it back using the pipe.load_lora_weights machinery
+            pipe.unload_lora_weights()
+
+            pipe.load_lora_weights(state_dict)
+
+            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(
+                not np.allclose(output_partial_lora, output_lora, atol=1e-4, rtol=1e-4),
+                "Removing adapters should change the output",
+            )
 
     def test_simple_inference_with_text_lora(self):
-        # TODO(aryan): Stop fighting me and just work!
-        pass
+        # We need a lower expected max diff than other lora pipelines apparently
+        super().test_simple_inference_with_text_lora(expected_atol=1e-4, expected_rtol=1e-4)
 
     def test_simple_inference_with_text_lora_and_scale(self):
-        # TODO(aryan): Stop fighting me and just work!
-        pass
+        # We need a lower expected max diff than other lora pipelines apparently
+        super().test_simple_inference_with_text_lora_and_scale(expected_atol=1e-4, expected_rtol=1e-4)
 
     def test_simple_inference_with_text_lora_fused(self):
-        # TODO(aryan): Stop fighting me and just work!
-        pass
-
-    def test_simple_inference_with_text_lora_save_load(self):
-        # TODO(aryan): Stop fighting me and just work!
-        pass
+        # We need a lower expected max diff than other lora pipelines apparently
+        super().test_simple_inference_with_text_lora_fused(expected_atol=1e-4, expected_rtol=1e-4)
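The partial-LoRA test above builds its checkpoint by filtering the `block.4` entries out of the PEFT state dict and re-prefixing the keys with `text_encoder.` before handing them to `pipe.load_lora_weights`. The same recipe can be exercised on a toy module; `TinyEncoder` below is a stand-in for the T5 text encoder, not the real pipeline component:

```python
import torch
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict


class TinyEncoder(torch.nn.Module):
    # Stand-in for the T5 encoder: six blocks, each with `q`/`v` projections.
    def __init__(self):
        super().__init__()
        self.block = torch.nn.ModuleList(
            torch.nn.ModuleDict({"q": torch.nn.Linear(8, 8), "v": torch.nn.Linear(8, 8)}) for _ in range(6)
        )

    def forward(self, x):
        for blk in self.block:
            x = blk["v"](blk["q"](x))
        return x


config = LoraConfig(r=4, lora_alpha=4, target_modules=["q", "v"], init_lora_weights=False)
model = get_peft_model(TinyEncoder(), config)

# Drop every LoRA tensor belonging to `block.4`, then prefix the keys the way
# the test feeds them back to the loader.
state_dict = {
    f"text_encoder.{name}": param
    for name, param in get_peft_model_state_dict(model).items()
    if "block.4" not in name
}
print(len(state_dict))  # 20 tensors: (6 - 1) blocks * 2 modules * (lora_A + lora_B)
```

Loading a state dict with a whole block missing is exactly the "missing layers" scenario from PR#8324: the loader must attach adapters only where keys exist instead of erroring out.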
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 00813d9ac2..3d03fb576f 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -205,6 +205,9 @@ class PeftLoraLoaderMixinTests:
         """
         Tests a simple inference and makes sure it works as expected
         """
+        # TODO(aryan): Some of the assumptions made here in many different tests are incorrect for CogVideoX.
+        # For example, we need to test with CogVideoXDDIMScheduler and CogVideoXDPMScheduler instead of
+        # DDIMScheduler and LCMScheduler, which CogVideoX does not support.
         scheduler_classes = (
             [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
@@ -218,7 +221,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs)[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
-    def test_simple_inference_with_text_lora(self):
+    def test_simple_inference_with_text_lora(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
         """
         Tests a simple inference with lora attached on the text encoder
         and makes sure it works as expected
@@ -249,10 +252,11 @@ class PeftLoraLoaderMixinTests:
             output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+                not np.allclose(output_lora, output_no_lora, atol=expected_atol, rtol=expected_rtol),
+                "Lora should change the output",
             )
 
-    def test_simple_inference_with_text_lora_and_scale(self):
+    def test_simple_inference_with_text_lora_and_scale(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
         """
         Tests a simple inference with lora attached on the text encoder + scale argument
         and makes sure it works as expected
         """
@@ -260,6 +264,14 @@ class PeftLoraLoaderMixinTests:
         scheduler_classes = (
             [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
+        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
+        attention_kwargs_name = None
+        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
+            if possible_attention_kwargs in call_signature_keys:
+                attention_kwargs_name = possible_attention_kwargs
+                break
+        assert attention_kwargs_name is not None
+
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
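The loop added here replaces the earlier `unet_kwargs` branching: instead of hard-coding which pipelines take `cross_attention_kwargs` versus `joint_attention_kwargs`, the test introspects `__call__`. Extracted into a standalone function (the name `determine_attention_kwargs_name` is illustrative, not part of the diff):

```python
import inspect


def determine_attention_kwargs_name(pipeline_class) -> str:
    # Return the first kwarg through which the pipeline forwards the LoRA scale.
    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
    for candidate in ("cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"):
        if candidate in call_signature_keys:
            return candidate
    raise ValueError(f"{pipeline_class.__name__} accepts no known attention-kwargs parameter")


class UNetStylePipeline:
    def __call__(self, prompt, cross_attention_kwargs=None): ...


class TransformerStylePipeline:
    def __call__(self, prompt, attention_kwargs=None): ...


print(determine_attention_kwargs_name(UNetStylePipeline))  # cross_attention_kwargs
print(determine_attention_kwargs_name(TransformerStylePipeline))  # attention_kwargs
```

Note the `attention_kwargs_name = None` initialization in the hunk above: without it, a pipeline matching none of the candidates would raise `NameError` at the `assert` instead of failing it cleanly.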
@@ -283,36 +294,27 @@ class PeftLoraLoaderMixinTests:
             output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             self.assertTrue(
-                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+                not np.allclose(output_lora, output_no_lora, atol=expected_atol, rtol=expected_rtol),
+                "Lora should change the output",
             )
 
-            if self.unet_kwargs is not None:
-                output_lora_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
-                )[0]
-            else:
-                output_lora_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5}
-                )[0]
+            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
+            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
             self.assertTrue(
-                not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+                not np.allclose(output_lora, output_lora_scale, atol=expected_atol, rtol=expected_rtol),
                 "Lora + scale should change the output",
             )
 
-            if self.unet_kwargs is not None:
-                output_lora_0_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
-                )[0]
-            else:
-                output_lora_0_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}
-                )[0]
+            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
+            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
             self.assertTrue(
-                np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+                np.allclose(output_no_lora, output_lora_0_scale, atol=expected_atol, rtol=expected_rtol),
                 "Lora + 0 scale should lead to same result as no LoRA",
             )
 
-    def test_simple_inference_with_text_lora_fused(self):
+    def test_simple_inference_with_text_lora_fused(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected
@@ -352,7 +354,8 @@ class PeftLoraLoaderMixinTests:
         ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertFalse(
-            np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+            np.allclose(ouput_fused, output_no_lora, atol=expected_atol, rtol=expected_rtol),
+            "Fused lora should change the output",
         )
 
     def test_simple_inference_with_text_lora_unloaded(self):
         """
         Tests a simple inference with lora attached to text encoder, then unloads the lora weights
         and makes sure it works as expected
@@ -606,9 +609,6 @@ class PeftLoraLoaderMixinTests:
         scheduler_classes = (
             [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
-        scheduler_classes = (
-            [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
-        )
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
@@ -693,6 +693,14 @@ class PeftLoraLoaderMixinTests:
         scheduler_classes = (
             [FlowMatchEulerDiscreteScheduler] if self.uses_flow_matching else [DDIMScheduler, LCMScheduler]
         )
+        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
+        attention_kwargs_name = None
+        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
+            if possible_attention_kwargs in call_signature_keys:
+                attention_kwargs_name = possible_attention_kwargs
+                break
+        assert attention_kwargs_name is not None
+
         for scheduler_cls in scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
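Both scale tests rest on the additive form of LoRA: the patched layer computes `base(x) + scale * B(A(x))`, so `scale=0.0` must reproduce the base model exactly while any nonzero scale perturbs it. A toy layer makes the invariant explicit (illustrative only; the real pipelines forward `scale` through the attention kwargs shown above rather than a forward argument):

```python
import torch


class ToyLoRALinear(torch.nn.Module):
    # base(x) + scale * B(A(x)) — the additive structure the scale tests rely on.
    def __init__(self, features: int = 8, rank: int = 2):
        super().__init__()
        self.base = torch.nn.Linear(features, features)
        self.lora_A = torch.nn.Linear(features, rank, bias=False)
        self.lora_B = torch.nn.Linear(rank, features, bias=False)
        torch.nn.init.normal_(self.lora_B.weight)  # nonzero, so the adapter has an effect

    def forward(self, x, scale: float = 1.0):
        return self.base(x) + scale * self.lora_B(self.lora_A(x))


torch.manual_seed(0)
layer = ToyLoRALinear()
x = torch.randn(1, 8)

assert not torch.allclose(layer(x, scale=0.5), layer(x, scale=1.0))  # "Lora + scale should change the output"
assert torch.allclose(layer(x, scale=0.0), layer.base(x))  # "Lora + 0 scale should lead to same result as no LoRA"
```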
@@ -724,36 +731,32 @@ class PeftLoraLoaderMixinTests:
                 not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
             )
 
-            if self.unet_kwargs is not None:
-                output_lora_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
-                )[0]
-            else:
-                output_lora_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5}
-                )[0]
+            attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
+            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
             self.assertTrue(
                 not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
                 "Lora + scale should change the output",
             )
 
-            if self.unet_kwargs is not None:
-                output_lora_0_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
-                )[0]
-            else:
-                output_lora_0_scale = pipe(
-                    **inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}
-                )[0]
+            attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
+            output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
             self.assertTrue(
                 np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
                 "Lora + 0 scale should lead to same result as no LoRA",
             )
 
-            self.assertTrue(
-                pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
-                "The scaling parameter has not been correctly restored!",
-            )
+            if hasattr(pipe.text_encoder, "text_model"):
+                self.assertTrue(
+                    pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
+                    "The scaling parameter has not been correctly restored!",
+                )
+            else:
+                self.assertTrue(
+                    pipe.text_encoder.encoder.block[0].layer[0].SelfAttention.q.scaling["default"] == 1.0,
+                    "The scaling parameter has not been correctly restored!",
+                )
 
     def test_simple_inference_with_text_lora_denoiser_fused(self):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected
@@ -802,9 +805,9 @@ class PeftLoraLoaderMixinTests:
                 check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
             )
 
-        ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertFalse(
-            np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
         )
 
     def test_simple_inference_with_text_denoiser_lora_unloaded(self):
@@ -1002,7 +1005,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached to text encoder and unet, attaches
         one adapter and set differnt weights for different blocks (i.e. block lora)
         """
-        if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
+        if self.pipeline_class.__name__ in ["StableDiffusion3Pipeline", "CogVideoXPipeline"]:
            return
 
        scheduler_classes = (
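The new `hasattr(pipe.text_encoder, "text_model")` branch exists because CLIP and T5 nest their attention projections differently, so the check that the LoRA `scaling` factor was restored to 1.0 has to walk two different attribute paths. A sketch of that dispatch with dummy stand-ins for both encoder families (the helper name is hypothetical):

```python
from types import SimpleNamespace


def first_attention_scaling(text_encoder):
    # CLIP layout: text_model.encoder.layers[0].self_attn.q_proj
    if hasattr(text_encoder, "text_model"):
        return text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling
    # T5 layout: encoder.block[0].layer[0].SelfAttention.q
    return text_encoder.encoder.block[0].layer[0].SelfAttention.q.scaling


def leaf(scaling):
    return SimpleNamespace(scaling=scaling)


clip_like = SimpleNamespace(
    text_model=SimpleNamespace(
        encoder=SimpleNamespace(layers=[SimpleNamespace(self_attn=SimpleNamespace(q_proj=leaf({"default": 1.0})))])
    )
)
t5_like = SimpleNamespace(
    encoder=SimpleNamespace(
        block=[SimpleNamespace(layer=[SimpleNamespace(SelfAttention=SimpleNamespace(q=leaf({"default": 1.0})))])]
    )
)

assert first_attention_scaling(clip_like)["default"] == 1.0
assert first_attention_scaling(t5_like)["default"] == 1.0
```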