diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 06eb3af05e..3eb7569967 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import re from contextlib import nullcontext from typing import Callable, Dict, List, Optional, Union @@ -44,13 +43,13 @@ from ..utils import ( set_adapter_layers, set_weights_and_activate_adapters, ) +from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers if is_transformers_available(): - from transformers import CLIPTextModel, CLIPTextModelWithProjection + from transformers import PreTrainedModel - # To be deprecated soon - from ..models.lora import PatchedLoraProjection + from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules if is_accelerate_available(): from accelerate import init_empty_weights @@ -67,37 +66,10 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future." -def text_encoder_attn_modules(text_encoder): - attn_modules = [] - - if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): - for i, layer in enumerate(text_encoder.text_model.encoder.layers): - name = f"text_model.encoder.layers.{i}.self_attn" - mod = layer.self_attn - attn_modules.append((name, mod)) - else: - raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}") - - return attn_modules - - -def text_encoder_mlp_modules(text_encoder): - mlp_modules = [] - - if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): - for i, layer in enumerate(text_encoder.text_model.encoder.layers): - mlp_mod = layer.mlp - name = f"text_model.encoder.layers.{i}.mlp" - mlp_modules.append((name, mlp_mod)) - else: - raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}") - - return mlp_modules - - class LoraLoaderMixin: r""" - Load LoRA layers into [`UNet2DConditionModel`] and [`~transformers.CLIPTextModel`]. + Load LoRA layers into [`UNet2DConditionModel`] and + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). """ text_encoder_name = TEXT_ENCODER_NAME @@ -123,28 +95,12 @@ class LoraLoaderMixin: Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - A string (model id of a pretrained model hosted on the Hub), a path to a directory containing the model - weights, or a [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. kwargs (`dict`, *optional*): See [`~loaders.LoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): - Name for referencing the loaded adapter model. If not specified, it will use `default_{i}` where `i` is - the total number of adapters being loaded. Must have PEFT installed to use. 
- - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to( - "cuda" - ) - pipeline.load_lora_weights( - "Yntec/pineappleAnimeMix", weight_name="pineappleAnimeMix_pineapple10.1.safetensors", adapter_name="anime" - ) - ``` + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. """ # First, ensure that the checkpoint is a compatible one and can be successfully loaded. state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) @@ -182,7 +138,15 @@ class LoraLoaderMixin: **kwargs, ): r""" - Return state dict and network alphas of the LoRA weights. + Return state dict for lora weights and the network alphas. + + + + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + + This function is experimental and might change in the future. + + Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): @@ -190,7 +154,8 @@ class LoraLoaderMixin: - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). @@ -226,6 +191,7 @@ class LoraLoaderMixin: Mirror source to resolve accessibility issues if you're downloading a model in China. We do not guarantee the timeliness or safety of the source, and you should refer to the mirror site for more information. + """ # Load the main state dict first which has the LoRA layers for either of # UNet and text encoder or both. @@ -322,8 +288,8 @@ class LoraLoaderMixin: # Map SDXL blocks correctly. if unet_config is not None: # use unet config to remap block numbers - state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) - state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict) + state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict) return state_dict, network_alphas @@ -363,109 +329,6 @@ class LoraLoaderMixin: weight_name = targeted_files[0] return weight_name - @classmethod - def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5): - # 1. get all state_dict_keys - all_keys = list(state_dict.keys()) - sgm_patterns = ["input_blocks", "middle_block", "output_blocks"] - - # 2. check if needs remapping, if not return original dict - is_in_sgm_format = False - for key in all_keys: - if any(p in key for p in sgm_patterns): - is_in_sgm_format = True - break - - if not is_in_sgm_format: - return state_dict - - # 3. 
Else remap from SGM patterns - new_state_dict = {} - inner_block_map = ["resnets", "attentions", "upsamplers"] - - # Retrieves # of down, mid and up blocks - input_block_ids, middle_block_ids, output_block_ids = set(), set(), set() - - for layer in all_keys: - if "text" in layer: - new_state_dict[layer] = state_dict.pop(layer) - else: - layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) - if sgm_patterns[0] in layer: - input_block_ids.add(layer_id) - elif sgm_patterns[1] in layer: - middle_block_ids.add(layer_id) - elif sgm_patterns[2] in layer: - output_block_ids.add(layer_id) - else: - raise ValueError(f"Checkpoint not supported because layer {layer} not supported.") - - input_blocks = { - layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] - for layer_id in input_block_ids - } - middle_blocks = { - layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] - for layer_id in middle_block_ids - } - output_blocks = { - layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key] - for layer_id in output_block_ids - } - - # Rename keys accordingly - for i in input_block_ids: - block_id = (i - 1) // (unet_config.layers_per_block + 1) - layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1) - - for key in input_blocks[i]: - inner_block_id = int(key.split(delimiter)[block_slice_pos]) - inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers" - inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0" - new_key = delimiter.join( - key.split(delimiter)[: block_slice_pos - 1] - + [str(block_id), inner_block_key, inner_layers_in_block] - + key.split(delimiter)[block_slice_pos + 1 :] - ) - new_state_dict[new_key] = state_dict.pop(key) - - for i in middle_block_ids: - key_part = None - if i == 0: - key_part = [inner_block_map[0], "0"] - elif i == 1: - key_part = [inner_block_map[1], "0"] - elif i == 2: - key_part = [inner_block_map[0], "1"] - else: - raise ValueError(f"Invalid middle block id {i}.") - - for key in middle_blocks[i]: - new_key = delimiter.join( - key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:] - ) - new_state_dict[new_key] = state_dict.pop(key) - - for i in output_block_ids: - block_id = i // (unet_config.layers_per_block + 1) - layer_in_block_id = i % (unet_config.layers_per_block + 1) - - for key in output_blocks[i]: - inner_block_id = int(key.split(delimiter)[block_slice_pos]) - inner_block_key = inner_block_map[inner_block_id] - inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0" - new_key = delimiter.join( - key.split(delimiter)[: block_slice_pos - 1] - + [str(block_id), inner_block_key, inner_layers_in_block] - + key.split(delimiter)[block_slice_pos + 1 :] - ) - new_state_dict[new_key] = state_dict.pop(key) - - if len(state_dict) > 0: - raise ValueError("At this point all state dict entries have to be converted.") - - return new_state_dict - @classmethod def _optionally_disable_offloading(cls, _pipeline): """ @@ -502,27 +365,25 @@ class LoraLoaderMixin: cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None ): """ - Load LoRA layers specified in `state_dict` into `unet`. + This will load the LoRA layers specified in `state_dict` into `unet`. Parameters: state_dict (`dict`): - A standard state dict containing the LoRA layer parameters. 
The keys can either be indexed directly - into the `unet` or prefixed with an additional `unet`, which can be used to distinguish between text - encoder LoRA layers. + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. network_alphas (`Dict[str, float]`): - See - [`LoRALinearLayer`](https://github.com/huggingface/diffusers/blob/c697f524761abd2314c030221a3ad2f7791eab4e/src/diffusers/models/lora.py#L182) - for more details. + See `LoRALinearLayer` for more details. unet (`UNet2DConditionModel`): The UNet model to load the LoRA layers into. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Only load and not initialize the pretrained weights. This can speedup model loading and also tries to - not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only - supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to - `True` will raise an error. + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. adapter_name (`str`, *optional*): - Name for referencing the loaded adapter model. If not specified, it will use `default_{i}` where `i` is - the total number of adapters being loaded. + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. """ low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -616,27 +477,26 @@ class LoraLoaderMixin: _pipeline=None, ): """ - Load LoRA layers specified in `state_dict` into `text_encoder`. + This will load the LoRA layers specified in `state_dict` into `text_encoder` Parameters: state_dict (`dict`): - A standard state dict containing the LoRA layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between UNet LoRA layers. + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. network_alphas (`Dict[str, float]`): - See - [`LoRALinearLayer`](https://github.com/huggingface/diffusers/blob/c697f524761abd2314c030221a3ad2f7791eab4e/src/diffusers/models/lora.py#L182) - for more details. + See `LoRALinearLayer` for more details. text_encoder (`CLIPTextModel`): The text encoder model to load the LoRA layers into. prefix (`str`): Expected prefix of the `text_encoder` in the `state_dict`. lora_scale (`float`): - Scale of `LoRALinearLayer`'s output before it is added with the output of the regular LoRA layer. + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Only load and not initialize the pretrained weights. 
This can speedup model loading and also tries to - not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only - supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to - `True` will raise an error. + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. adapter_name (`str`, *optional*): Adapter name to be used for referencing the loaded adapter model. If not specified, it will use `default_{i}` where i is the total number of adapters being loaded. @@ -921,11 +781,11 @@ class LoraLoaderMixin: safe_serialization: bool = True, ): r""" - Save the UNet and text encoder LoRA parameters. + Save the LoRA parameters corresponding to the UNet and text encoder. Arguments: save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to (will be created if it doesn't exist). + Directory to save LoRA parameters to. Will be created if it doesn't exist. unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `unet`. text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): @@ -936,30 +796,11 @@ class LoraLoaderMixin: need to call this function on all processes. In this case, set `is_main_process=True` only on the main process to avoid race conditions. save_function (`Callable`): - The function to use to save the state dict. Useful during distributed training when you need to replace - `torch.save` with another method. Can be configured with the environment variable + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or with `pickle`. - - Example: - - ```py - from diffusers import StableDiffusionXLPipeline - from peft.utils import get_peft_model_state_dict - import torch - - pipeline = StableDiffusionXLPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora() - - # get and save unet state dict - unet_state_dict = get_peft_model_state_dict(pipeline.unet, adapter_name="pixel") - pipeline.save_lora_weights("fused-model", unet_lora_layers=unet_state_dict) - pipeline.load_lora_weights("fused-model", weight_name="pytorch_lora_weights.safetensors") - ``` + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. """ # Create a flat dictionary. 
state_dict = {} @@ -1028,186 +869,16 @@ class LoraLoaderMixin: save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") - @classmethod - def _convert_kohya_lora_to_diffusers(cls, state_dict): - unet_state_dict = {} - te_state_dict = {} - te2_state_dict = {} - network_alphas = {} - - # every down weight has a corresponding up weight and potentially an alpha weight - lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] - for key in lora_keys: - lora_name = key.split(".")[0] - lora_name_up = lora_name + ".lora_up.weight" - lora_name_alpha = lora_name + ".alpha" - - if lora_name.startswith("lora_unet_"): - diffusers_name = key.replace("lora_unet_", "").replace("_", ".") - - if "input.blocks" in diffusers_name: - diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") - else: - diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") - - if "middle.block" in diffusers_name: - diffusers_name = diffusers_name.replace("middle.block", "mid_block") - else: - diffusers_name = diffusers_name.replace("mid.block", "mid_block") - if "output.blocks" in diffusers_name: - diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") - else: - diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") - - diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") - diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") - diffusers_name = diffusers_name.replace("proj.in", "proj_in") - diffusers_name = diffusers_name.replace("proj.out", "proj_out") - diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") - - # SDXL specificity. - if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name: - pattern = r"\.\d+(?=\D*$)" - diffusers_name = re.sub(pattern, "", diffusers_name, count=1) - if ".in." in diffusers_name: - diffusers_name = diffusers_name.replace("in.layers.2", "conv1") - if ".out." in diffusers_name: - diffusers_name = diffusers_name.replace("out.layers.3", "conv2") - if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: - diffusers_name = diffusers_name.replace("op", "conv") - if "skip" in diffusers_name: - diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") - - # LyCORIS specificity. - if "time.emb.proj" in diffusers_name: - diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj") - if "conv.shortcut" in diffusers_name: - diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut") - - # General coverage. 
- if "transformer_blocks" in diffusers_name: - if "attn1" in diffusers_name or "attn2" in diffusers_name: - diffusers_name = diffusers_name.replace("attn1", "attn1.processor") - diffusers_name = diffusers_name.replace("attn2", "attn2.processor") - unet_state_dict[diffusers_name] = state_dict.pop(key) - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "ff" in diffusers_name: - unet_state_dict[diffusers_name] = state_dict.pop(key) - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif any(key in diffusers_name for key in ("proj_in", "proj_out")): - unet_state_dict[diffusers_name] = state_dict.pop(key) - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - else: - unet_state_dict[diffusers_name] = state_dict.pop(key) - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - - elif lora_name.startswith("lora_te_"): - diffusers_name = key.replace("lora_te_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("text.model", "text_model") - diffusers_name = diffusers_name.replace("self.attn", "self_attn") - diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - if "self_attn" in diffusers_name: - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "mlp" in diffusers_name: - # Be aware that this is the new diffusers convention and the rest of the code might - # not utilize it yet. - diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - - # (sayakpaul): Duplicate code. Needs to be cleaned. - elif lora_name.startswith("lora_te1_"): - diffusers_name = key.replace("lora_te1_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("text.model", "text_model") - diffusers_name = diffusers_name.replace("self.attn", "self_attn") - diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - if "self_attn" in diffusers_name: - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "mlp" in diffusers_name: - # Be aware that this is the new diffusers convention and the rest of the code might - # not utilize it yet. - diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") - te_state_dict[diffusers_name] = state_dict.pop(key) - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - - # (sayakpaul): Duplicate code. Needs to be cleaned. 
- elif lora_name.startswith("lora_te2_"): - diffusers_name = key.replace("lora_te2_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("text.model", "text_model") - diffusers_name = diffusers_name.replace("self.attn", "self_attn") - diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - if "self_attn" in diffusers_name: - te2_state_dict[diffusers_name] = state_dict.pop(key) - te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - elif "mlp" in diffusers_name: - # Be aware that this is the new diffusers convention and the rest of the code might - # not utilize it yet. - diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") - te2_state_dict[diffusers_name] = state_dict.pop(key) - te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) - - # Rename the alphas so that they can be mapped appropriately. - if lora_name_alpha in state_dict: - alpha = state_dict.pop(lora_name_alpha).item() - if lora_name_alpha.startswith("lora_unet_"): - prefix = "unet." - elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): - prefix = "text_encoder." - else: - prefix = "text_encoder_2." - new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" - network_alphas.update({new_name: alpha}) - - if len(state_dict) > 0: - raise ValueError( - f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}" - ) - - logger.info("Kohya-style checkpoint detected.") - unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} - te_state_dict = { - f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items() - } - te2_state_dict = ( - {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()} - if len(te2_state_dict) > 0 - else None - ) - if te2_state_dict is not None: - te_state_dict.update(te2_state_dict) - - new_state_dict = {**unet_state_dict, **te_state_dict} - return new_state_dict, network_alphas - def unload_lora_weights(self): """ - Unload the LoRA parameters from a pipeline. + Unloads the LoRA parameters. Examples: - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.unload_lora_weights() + ```python + >>> # Assuming `pipeline` is already loaded with the LoRA parameters. + >>> pipeline.unload_lora_weights() + >>> ... ``` """ if not USE_PEFT_BACKEND: @@ -1236,7 +907,7 @@ class LoraLoaderMixin: safe_fusing: bool = False, ): r""" - Fuse the LoRA parameters with the original parameters in their corresponding blocks. + Fuses the LoRA parameters into the original parameters of the corresponding blocks. @@ -1250,23 +921,9 @@ class LoraLoaderMixin: Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. lora_scale (`float`, defaults to 1.0): - Controls LoRA influence on the outputs. + Controls how much to influence the outputs with the LoRA parameters. 
safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for `NaN` values before fusing and if values are `NaN`, then don't fuse - them. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. """ if fuse_unet or fuse_text_encoder: self.num_fused_loras += 1 @@ -1315,7 +972,8 @@ class LoraLoaderMixin: def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True): r""" - Unfuse the LoRA parameters from the original parameters in their corresponding blocks. + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora). @@ -1328,20 +986,6 @@ class LoraLoaderMixin: unfuse_text_encoder (`bool`, defaults to `True`): Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the LoRA parameters then it won't have any effect. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - pipeline.unfuse_lora() - ``` """ if unfuse_unet: if not USE_PEFT_BACKEND: @@ -1393,32 +1037,16 @@ class LoraLoaderMixin: text_encoder_weights: List[float] = None, ): """ - Set the currently active adapter for use in the text encoder. + Sets the adapter layers for the text encoder. Args: adapter_names (`List[str]` or `str`): - The adapter to activate. + The names of the adapters to use. text_encoder (`torch.nn.Module`, *optional*): - The text encoder module to activate the adapter layers for. If `None`, it will try to get the - `text_encoder` attribute. + The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder` + attribute. text_encoder_weights (`List[float]`, *optional*): The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.load_lora_weights( - "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" - ) - pipeline.set_adapters_for_text_encoder("pixel") - ``` """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1444,27 +1072,14 @@ class LoraLoaderMixin: ) set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights) - def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821 + def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): """ - Disable the text encoder's LoRA layers. 
+ Disables the LoRA layers for the text encoder. Args: text_encoder (`torch.nn.Module`, *optional*): The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder` attribute. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.disable_lora_for_text_encoder() - ``` """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1474,27 +1089,14 @@ class LoraLoaderMixin: raise ValueError("Text Encoder not found.") set_adapter_layers(text_encoder, enabled=False) - def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821 + def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): """ - Enables the text encoder's LoRA layers. + Enables the LoRA layers for the text encoder. Args: text_encoder (`torch.nn.Module`, *optional*): The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder` attribute. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.enable_lora_for_text_encoder() - ``` """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1545,24 +1147,10 @@ class LoraLoaderMixin: def delete_adapters(self, adapter_names: Union[List[str], str]): """ - Delete an adapter's LoRA layers from the UNet and text encoder(s). - Args: + Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s). adapter_names (`Union[List[str], str]`): - The names (single string or list of strings) of the adapter to delete. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.delete_adapters("pixel") - ``` + The names of the adapter to delete. Can be a single string or a list of strings """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1582,7 +1170,7 @@ class LoraLoaderMixin: def get_active_adapters(self) -> List[str]: """ - Get a list of currently active adapters. + Gets the list of the current active adapters. Example: @@ -1614,22 +1202,7 @@ class LoraLoaderMixin: def get_list_adapters(self) -> Dict[str, List[str]]: """ - Get a list of all currently available adapters for each component in the pipeline. 
- - Example: - - ```py - from diffusers import DiffusionPipeline - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - ).to("cuda") - pipeline.load_lora_weights( - "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" - ) - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.get_list_adapters() - ``` + Gets the current list of all available adapters in the pipeline. """ if not USE_PEFT_BACKEND: raise ValueError( @@ -1651,27 +1224,14 @@ class LoraLoaderMixin: def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None: """ - Move a LoRA to a target device. Useful for offloading a LoRA to the CPU in case you want to load multiple - adapters and free some GPU memory. + Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case + you want to load multiple adapters and free some GPU memory. Args: adapter_names (`List[str]`): - List of adapters to send to device. + List of adapters to send device to. device (`Union[torch.device, str, int]`): - Device (can be a `torch.device`, `str` or `int`) to place adapters on. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.set_lora_device(["pixel"], device="cuda") - ``` + Device to send the adapters to. Can be either a torch device, a str or an integer. """ if not USE_PEFT_BACKEND: raise ValueError("PEFT backend is required for this method.") @@ -1703,7 +1263,7 @@ class LoraLoaderMixin: class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): - """This class overrides [`LoraLoaderMixin`] with LoRA loading/saving code that's specific to SDXL.""" + """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL""" # Overrride to properly handle the loading and unloading of the additional text encoder. def load_lora_weights( @@ -1728,26 +1288,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - A string (model id of a pretrained model hosted on the Hub), a path to a directory containing the model - weights, or a [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - kwargs (`dict`, *optional*): See [`~loaders.LoraLoaderMixin.lora_state_dict`]. adapter_name (`str`, *optional*): - Name for referencing the loaded adapter model. If not specified, it will use `default_{i}` where `i` is - the total number of adapters being loaded. Must have PEFT installed to use. - - Example: - - ```py - from diffusers import StableDiffusionXLPipeline - import torch - - pipeline = StableDiffusionXLPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - ``` + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + kwargs (`dict`, *optional*): + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. 
""" # We could have accessed the unet config from `lora_state_dict()` too. We pass # it here explicitly to be able to tell that it's coming from an SDXL diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py new file mode 100644 index 0000000000..4a89fc20b5 --- /dev/null +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -0,0 +1,284 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from ..utils import logging + + +logger = logging.get_logger(__name__) + + +def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5): + # 1. get all state_dict_keys + all_keys = list(state_dict.keys()) + sgm_patterns = ["input_blocks", "middle_block", "output_blocks"] + + # 2. check if needs remapping, if not return original dict + is_in_sgm_format = False + for key in all_keys: + if any(p in key for p in sgm_patterns): + is_in_sgm_format = True + break + + if not is_in_sgm_format: + return state_dict + + # 3. Else remap from SGM patterns + new_state_dict = {} + inner_block_map = ["resnets", "attentions", "upsamplers"] + + # Retrieves # of down, mid and up blocks + input_block_ids, middle_block_ids, output_block_ids = set(), set(), set() + + for layer in all_keys: + if "text" in layer: + new_state_dict[layer] = state_dict.pop(layer) + else: + layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) + if sgm_patterns[0] in layer: + input_block_ids.add(layer_id) + elif sgm_patterns[1] in layer: + middle_block_ids.add(layer_id) + elif sgm_patterns[2] in layer: + output_block_ids.add(layer_id) + else: + raise ValueError(f"Checkpoint not supported because layer {layer} not supported.") + + input_blocks = { + layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] + for layer_id in input_block_ids + } + middle_blocks = { + layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] + for layer_id in middle_block_ids + } + output_blocks = { + layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key] + for layer_id in output_block_ids + } + + # Rename keys accordingly + for i in input_block_ids: + block_id = (i - 1) // (unet_config.layers_per_block + 1) + layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1) + + for key in input_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers" + inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in middle_block_ids: + key_part = None + if i == 0: + key_part = [inner_block_map[0], "0"] + elif i == 1: + key_part = [inner_block_map[1], "0"] + elif 
i == 2: + key_part = [inner_block_map[0], "1"] + else: + raise ValueError(f"Invalid middle block id {i}.") + + for key in middle_blocks[i]: + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in output_block_ids: + block_id = i // (unet_config.layers_per_block + 1) + layer_in_block_id = i % (unet_config.layers_per_block + 1) + + for key in output_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] + inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + if len(state_dict) > 0: + raise ValueError("At this point all state dict entries have to be converted.") + + return new_state_dict + + +def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"): + unet_state_dict = {} + te_state_dict = {} + te2_state_dict = {} + network_alphas = {} + + # every down weight has a corresponding up weight and potentially an alpha weight + lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] + for key in lora_keys: + lora_name = key.split(".")[0] + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" + + if lora_name.startswith("lora_unet_"): + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + + if "input.blocks" in diffusers_name: + diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") + else: + diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + + if "middle.block" in diffusers_name: + diffusers_name = diffusers_name.replace("middle.block", "mid_block") + else: + diffusers_name = diffusers_name.replace("mid.block", "mid_block") + if "output.blocks" in diffusers_name: + diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") + else: + diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") + + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") + diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") + + # SDXL specificity. + if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name: + pattern = r"\.\d+(?=\D*$)" + diffusers_name = re.sub(pattern, "", diffusers_name, count=1) + if ".in." in diffusers_name: + diffusers_name = diffusers_name.replace("in.layers.2", "conv1") + if ".out." in diffusers_name: + diffusers_name = diffusers_name.replace("out.layers.3", "conv2") + if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: + diffusers_name = diffusers_name.replace("op", "conv") + if "skip" in diffusers_name: + diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") + + # LyCORIS specificity. 
+ if "time.emb.proj" in diffusers_name: + diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj") + if "conv.shortcut" in diffusers_name: + diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut") + + # General coverage. + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "ff" in diffusers_name: + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + else: + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + elif lora_name.startswith("lora_te_"): + diffusers_name = key.replace("lora_te_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # (sayakpaul): Duplicate code. Needs to be cleaned. + elif lora_name.startswith("lora_te1_"): + diffusers_name = key.replace("lora_te1_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # (sayakpaul): Duplicate code. Needs to be cleaned. 
+ elif lora_name.startswith("lora_te2_"): + diffusers_name = key.replace("lora_te2_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # Rename the alphas so that they can be mapped appropriately. + if lora_name_alpha in state_dict: + alpha = state_dict.pop(lora_name_alpha).item() + if lora_name_alpha.startswith("lora_unet_"): + prefix = "unet." + elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): + prefix = "text_encoder." + else: + prefix = "text_encoder_2." + new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" + network_alphas.update({new_name: alpha}) + + if len(state_dict) > 0: + raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}") + + logger.info("Kohya-style checkpoint detected.") + unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()} + te2_state_dict = ( + {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()} + if len(te2_state_dict) > 0 + else None + ) + if te2_state_dict is not None: + te_state_dict.update(te2_state_dict) + + new_state_dict = {**unet_state_dict, **te_state_dict} + return new_state_dict, network_alphas
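
The new `lora_conversion_utils` module is what `LoraLoaderMixin.lora_state_dict` dispatches to when it detects an SGM/Kohya-style checkpoint. As a minimal sketch of that conversion path (the local checkpoint filename is hypothetical; the direct calls to the private helpers are for illustration only, since in normal use `lora_state_dict` / `load_lora_weights` invoke them for you):

```py
import torch
from safetensors.torch import load_file

from diffusers import UNet2DConditionModel
from diffusers.loaders.lora_conversion_utils import (
    _convert_kohya_lora_to_diffusers,
    _maybe_map_sgm_blocks_to_diffusers,
)

# The UNet config is only needed so the SGM remapper knows `layers_per_block`.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)

# Hypothetical local Kohya/A1111-style LoRA checkpoint.
state_dict = load_file("my_kohya_lora.safetensors")

# Remap SGM block names ("input_blocks", "middle_block", "output_blocks") to diffusers
# block names; the dict is returned unchanged if it is not in SGM format.
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet.config)

# Convert the Kohya key layout ("lora_unet_*", "lora_te_*", "lora_te1_*", "lora_te2_*")
# to diffusers-style keys prefixed with "unet." / "text_encoder." / "text_encoder_2.",
# and collect the per-module network alphas.
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
```

In practice, `pipeline.load_lora_weights(...)` runs this conversion transparently whenever a `unet_config` is available and the checkpoint requires remapping, so the helpers only need to be called directly when inspecting or post-processing a converted state dict.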