mirror of https://github.com/huggingface/diffusers.git
[LoRA] text encoder: read the ranks for all the attn modules (#8324)
* [LoRA] text encoder: read the ranks for all the attn modules
* In addition to out_proj, read the ranks of adapters for q_proj, k_proj, and v_proj
* Allow missing adapters (the UNet already supports this)
* ruff format loaders.lora
* [LoRA] add tests for partial text-encoder LoRAs
* [LoRA] update test_simple_inference_with_partial_text_lora to be deterministic
* [LoRA] comment justifying test_simple_inference_with_partial_text_lora
* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
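Concretely, the first hunk below handles state dicts like the following (a minimal sketch, not data from the commit; the key names assume the PEFT-converted CLIP text-encoder layout, and the shapes are invented for illustration):

import torch

# Hypothetical PEFT-format entries for one attention block. Each projection
# can carry its own rank, and a projection may have no adapter at all.
text_encoder_lora_state_dict = {
    "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(32, 1),    # rank 1
    "text_model.encoder.layers.0.self_attn.k_proj.lora_B.weight": torch.zeros(32, 2),    # rank 2
    "text_model.encoder.layers.0.self_attn.out_proj.lora_B.weight": torch.zeros(32, 4),  # rank 4
    # v_proj intentionally absent: a "missing adapter".
}

# lora_B has shape (out_features, rank), so shape[1] recovers each module's rank.
ranks = {key: w.shape[1] for key, w in text_encoder_lora_state_dict.items()}
print(ranks)  # q_proj -> 1, k_proj -> 2, out_proj -> 4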
@@ -462,17 +462,18 @@ class LoraLoaderMixin:
                 text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
 
                 for name, _ in text_encoder_attn_modules(text_encoder):
-                    rank_key = f"{name}.out_proj.lora_B.weight"
-                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+                    for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
+                        rank_key = f"{name}.{module}.lora_B.weight"
+                        if rank_key not in text_encoder_lora_state_dict:
+                            continue
+                        rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
 
                 patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
                 if patch_mlp:
-                    for name, _ in text_encoder_mlp_modules(text_encoder):
-                        rank_key_fc1 = f"{name}.fc1.lora_B.weight"
-                        rank_key_fc2 = f"{name}.fc2.lora_B.weight"
-
-                        rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
-                        rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
+                    for name, _ in text_encoder_mlp_modules(text_encoder):
+                        for module in ("fc1", "fc2"):
+                            rank_key = f"{name}.{module}.lora_B.weight"
+                            if rank_key not in text_encoder_lora_state_dict:
+                                continue
+                            rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
 
                 if network_alphas is not None:
                     alpha_keys = [
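Read back, the new attention loop is equivalent to the following standalone helper (a sketch for clarity, not library code; `collect_ranks` and its parameter names are made up):

def collect_ranks(state_dict, module_names, sub_modules=("out_proj", "q_proj", "k_proj", "v_proj")):
    """Collect the rank of every adapter actually present; skip absent ones."""
    ranks = {}
    for name in module_names:
        for sub in sub_modules:
            key = f"{name}.{sub}.lora_B.weight"
            if key not in state_dict:
                # Previously only out_proj was read, and a missing key raised
                # a KeyError; now missing adapters are simply skipped.
                continue
            ranks[key] = state_dict[key].shape[1]  # lora_B: (out_features, rank)
    return ranks

The MLP path applies the same pattern with sub_modules=("fc1", "fc2").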
@@ -395,6 +395,69 @@ class PeftLoraLoaderMixinTests:
                     "Loading from saved checkpoints should give same results.",
                 )
 
+    def test_simple_inference_with_partial_text_lora(self):
+        """
+        Tests a simple inference with lora attached on the text encoder
+        with different ranks and some adapters removed
+        and makes sure it works as expected
+        """
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, _, _ = self.get_dummy_components(scheduler_cls)
+            # Verify `LoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
+            text_lora_config = LoraConfig(
+                r=4,
+                rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
+                lora_alpha=4,
+                target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+                init_lora_weights=False,
+                use_dora=False,
+            )
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+            pipe.text_encoder.add_adapter(text_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+            # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
+            # supports missing layers (PR#8324).
+            state_dict = {
+                f"text_encoder.{module_name}": param
+                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
+                if "text_model.encoder.layers.4" not in module_name
+            }
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+                state_dict.update(
+                    {
+                        f"text_encoder_2.{module_name}": param
+                        for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
+                        if "text_model.encoder.layers.4" not in module_name
+                    }
+                )
+
+            output_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(
+                not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+            )
+
+            # Unload lora and load it back using the pipe.load_lora_weights machinery
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(state_dict)
+
+            output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(
+                not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
+                "Removing adapters should change the output",
+            )
+
     def test_simple_inference_save_pretrained(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
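From the user side, the change means a checkpoint whose text-encoder LoRA covers only some layers, or mixes ranks across projections, now loads through the normal entry point (a sketch; the checkpoint path is a placeholder, not a real file):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Before this change, a text-encoder LoRA whose adapters cover only some of
# the q/k/v/out projections (or skip whole layers) raised a KeyError here;
# now absent modules are skipped and the rest load with their own ranks.
pipe.load_lora_weights("path/to/partial_text_encoder_lora.safetensors")  # placeholder path
image = pipe("an astronaut riding a horse").images[0]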