From 298ce679992e439a9869632579a2f853ea2b095e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=A6ros?= <163143799+elias-gaeros@users.noreply.github.com> Date: Tue, 18 Jun 2024 22:10:50 +0200 Subject: [PATCH] [LoRA] text encoder: read the ranks for all the attn modules (#8324) * [LoRA] text encoder: read the ranks for all the attn modules * In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and v_proj * Allow missing adapters (UNet already supports this) * ruff format loaders.lora * [LoRA] add tests for partial text encoders LoRAs * [LoRA] update test_simple_inference_with_partial_text_lora to be deterministic * [LoRA] comment justifying test_simple_inference_with_partial_text_lora * style --------- Co-authored-by: Sayak Paul --- src/diffusers/loaders/lora.py | 21 ++++++------ tests/lora/utils.py | 63 +++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 10 deletions(-) diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 120a89533d..194183896d 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -462,17 +462,18 @@ class LoraLoaderMixin: text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) for name, _ in text_encoder_attn_modules(text_encoder): - rank_key = f"{name}.out_proj.lora_B.weight" - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) - if patch_mlp: - for name, _ in text_encoder_mlp_modules(text_encoder): - rank_key_fc1 = f"{name}.fc1.lora_B.weight" - rank_key_fc2 = f"{name}.fc2.lora_B.weight" - - rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1] - rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1] + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] if network_alphas is not None: alpha_keys = [ diff --git a/tests/lora/utils.py b/tests/lora/utils.py index d08a266456..9a07727db9 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -395,6 +395,69 @@ class PeftLoraLoaderMixinTests: "Loading from saved checkpoints should give same results.", ) + def test_simple_inference_with_partial_text_lora(self): + """ + Tests a simple inference with lora attached on the text encoder + with different ranks and some adapters removed + and makes sure it works as expected + """ + for scheduler_cls in [DDIMScheduler, LCMScheduler]: + components, _, _ = self.get_dummy_components(scheduler_cls) + # Verify `LoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324). + text_lora_config = LoraConfig( + r=4, + rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3}, + lora_alpha=4, + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + init_lora_weights=False, + use_dora=False, + ) + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + + pipe.text_encoder.add_adapter(text_lora_config) + self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") + # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` + # supports missing layers (PR#8324). + state_dict = { + f"text_encoder.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() + if "text_model.encoder.layers.4" not in module_name + } + + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config) + self.assertTrue( + check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + state_dict.update( + { + f"text_encoder_2.{module_name}": param + for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items() + if "text_model.encoder.layers.4" not in module_name + } + ) + + output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue( + not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output" + ) + + # Unload lora and load it back using the pipe.load_lora_weights machinery + pipe.unload_lora_weights() + pipe.load_lora_weights(state_dict) + + output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue( + not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3), + "Removing adapters should change the output", + ) + def test_simple_inference_save_pretrained(self): """ Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained