diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 94910956d0..a6b5d88b12 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -43,7 +43,7 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
-from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
+from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
 
 
 if is_transformers_available():
@@ -288,7 +288,7 @@ class LoraLoaderMixin:
         if unet_config is not None:
             # use unet config to remap block numbers
             state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
-        state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
+        state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
 
         return state_dict, network_alphas
 
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index e233c916f9..338175b751 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -123,134 +123,76 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
     return new_state_dict
 
 
-def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
+    """
+    Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
+
+    Args:
+        state_dict (`dict`): The state dict to convert.
+        unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
+        text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
+            "text_encoder".
+
+    Returns:
+        `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
+    """
     unet_state_dict = {}
     te_state_dict = {}
     te2_state_dict = {}
     network_alphas = {}
 
-    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
-    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
-    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
-    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+    # Check for DoRA-enabled LoRAs.
+    if any(
+        "dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
+        for k in state_dict
+    ):
         if is_peft_version("<", "0.9.0"):
             raise ValueError(
                 "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
             )
 
-    # every down weight has a corresponding up weight and potentially an alpha weight
-    lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
-    for key in lora_keys:
+    # Iterate over all LoRA weights.
+    all_lora_keys = list(state_dict.keys())
+    for key in all_lora_keys:
+        if not key.endswith("lora_down.weight"):
+            continue
+
+        # Extract LoRA name.
         lora_name = key.split(".")[0]
+
+        # Find corresponding up weight and alpha.
         lora_name_up = lora_name + ".lora_up.weight"
         lora_name_alpha = lora_name + ".alpha"
 
+        # Handle U-Net LoRAs.
         if lora_name.startswith("lora_unet_"):
-            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+            diffusers_name = _convert_unet_lora_key(key)
 
-            if "input.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+            # Store down and up weights.
+            unet_state_dict[diffusers_name] = state_dict.pop(key)
+            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
-            if "middle.block" in diffusers_name:
-                diffusers_name = diffusers_name.replace("middle.block", "mid_block")
-            else:
-                diffusers_name = diffusers_name.replace("mid.block", "mid_block")
-            if "output.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
-
-            diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
-            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
-            diffusers_name = diffusers_name.replace("proj.in", "proj_in")
-            diffusers_name = diffusers_name.replace("proj.out", "proj_out")
-            diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
-
-            # SDXL specificity.
-            if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
-                pattern = r"\.\d+(?=\D*$)"
-                diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
-            if ".in." in diffusers_name:
-                diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
-            if ".out." in diffusers_name:
-                diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
-            if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
-                diffusers_name = diffusers_name.replace("op", "conv")
-            if "skip" in diffusers_name:
-                diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
-
-            # LyCORIS specificity.
-            if "time.emb.proj" in diffusers_name:
-                diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
-            if "conv.shortcut" in diffusers_name:
-                diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
-
-            # General coverage.
-            if "transformer_blocks" in diffusers_name:
-                if "attn1" in diffusers_name or "attn2" in diffusers_name:
-                    diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
-                    diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
-                    unet_state_dict[diffusers_name] = state_dict.pop(key)
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "ff" in diffusers_name:
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            else:
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
-            if is_unet_dora_lora:
+            # Store DoRA scale if present.
+            if key.replace("lora_down.weight", "dora_scale") in state_dict:
                 dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
                 unet_state_dict[
                     diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
                 ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
 
+        # Handle text encoder LoRAs.
         elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
+
+            # Store down and up weights for te or te2.
             if lora_name.startswith(("lora_te_", "lora_te1_")):
-                key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
+                te_state_dict[diffusers_name] = state_dict.pop(key)
+                te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
             else:
-                key_to_replace = "lora_te2_"
-
-            diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
-            diffusers_name = diffusers_name.replace("text.model", "text_model")
-            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-            diffusers_name = diffusers_name.replace("text.projection", "text_projection")
-
-            if "self_attn" in diffusers_name:
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name] = state_dict.pop(key)
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                else:
-                    te2_state_dict[diffusers_name] = state_dict.pop(key)
-                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name] = state_dict.pop(key)
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                else:
-                    te2_state_dict[diffusers_name] = state_dict.pop(key)
-                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            # OneTrainer specificity
-            elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
                 te2_state_dict[diffusers_name] = state_dict.pop(key)
                 te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
 
-            if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
+            # Store DoRA scale if present.
+            if key.replace("lora_down.weight", "dora_scale") in state_dict:
                 dora_scale_key_to_replace_te = (
                     "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
                 )
@@ -263,22 +205,18 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
                     diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
                 ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
 
-        # Rename the alphas so that they can be mapped appropriately.
+        # Store alpha if present.
         if lora_name_alpha in state_dict:
             alpha = state_dict.pop(lora_name_alpha).item()
-            if lora_name_alpha.startswith("lora_unet_"):
-                prefix = "unet."
-            elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
-                prefix = "text_encoder."
-            else:
-                prefix = "text_encoder_2."
-            new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
-            network_alphas.update({new_name: alpha})
+            network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
 
+    # Check if any keys remain.
     if len(state_dict) > 0:
         raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
 
     logger.info("Kohya-style checkpoint detected.")
+
+    # Construct final state dict.
     unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
     te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
     te2_state_dict = (
@@ -291,3 +229,100 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
     new_state_dict = {**unet_state_dict, **te_state_dict}
 
     return new_state_dict, network_alphas
+
+
+def _convert_unet_lora_key(key):
+    """
+    Converts a U-Net LoRA key to a Diffusers compatible key.
+    """
+    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+
+    # Replace common U-Net naming patterns.
+    diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
+    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+    diffusers_name = diffusers_name.replace("middle.block", "mid_block")
+    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+    diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
+    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+    diffusers_name = diffusers_name.replace("proj.in", "proj_in")
+    diffusers_name = diffusers_name.replace("proj.out", "proj_out")
+    diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
+
+    # SDXL specific conversions.
+    if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
+        pattern = r"\.\d+(?=\D*$)"
+        diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
+    if ".in." in diffusers_name:
+        diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
+    if ".out." in diffusers_name:
+        diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
+    if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
+        diffusers_name = diffusers_name.replace("op", "conv")
+    if "skip" in diffusers_name:
+        diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
+
+    # LyCORIS specific conversions.
+    if "time.emb.proj" in diffusers_name:
+        diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
+    if "conv.shortcut" in diffusers_name:
+        diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
+
+    # General conversions. Keys outside `transformer_blocks` (e.g. "ff",
+    # "proj_in"/"proj_out") need no further renaming.
+    if "transformer_blocks" in diffusers_name:
+        if "attn1" in diffusers_name or "attn2" in diffusers_name:
+            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+
+    return diffusers_name
+
+
+def _convert_text_encoder_lora_key(key, lora_name):
+    """
+    Converts a text encoder LoRA key to a Diffusers compatible key.
+ """ + if lora_name.startswith(("lora_te_", "lora_te1_")): + key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_" + else: + key_to_replace = "lora_te2_" + + diffusers_name = key.replace(key_to_replace, "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("text.projection", "text_projection") + + if "self_attn" in diffusers_name or "text_projection" in diffusers_name: + pass + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + return diffusers_name + + +def _get_alpha_name(lora_name_alpha, diffusers_name, alpha): + """ + Gets the correct alpha name for the Diffusers model. + """ + if lora_name_alpha.startswith("lora_unet_"): + prefix = "unet." + elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): + prefix = "text_encoder." + else: + prefix = "text_encoder_2." + new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" + return {new_name: alpha}