From 79371661d1eda82bad6861702faafc854c631234 Mon Sep 17 00:00:00 2001 From: Valeriy Selitskiy <239034+iamwavecut@users.noreply.github.com> Date: Tue, 6 May 2025 15:14:58 +0200 Subject: [PATCH] [lora_conversion] Enhance key handling for OneTrainer components in LORA conversion utility (#11441) (#11487) * [lora_conversion] Enhance key handling for OneTrainer components in LORA conversion utility (#11441) * Update src/diffusers/loaders/lora_conversion_utils.py Co-authored-by: Sayak Paul --------- Co-authored-by: Sayak Paul --- .../loaders/lora_conversion_utils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index a9e154af3c..d5fa7dcfc3 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -727,8 +727,25 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict): elif k.startswith("lora_te1_"): has_te_keys = True continue + elif k.startswith("lora_transformer_context_embedder"): + diffusers_key = "context_embedder" + elif k.startswith("lora_transformer_norm_out_linear"): + diffusers_key = "norm_out.linear" + elif k.startswith("lora_transformer_proj_out"): + diffusers_key = "proj_out" + elif k.startswith("lora_transformer_x_embedder"): + diffusers_key = "x_embedder" + elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}" + elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.text_embedder.linear_{i}" + elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"): + i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1]) + diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}" else: - raise NotImplementedError + raise NotImplementedError(f"Handling for key ({k}) is not implemented.") if "attn_" in k: if "_to_out_0" in k: