From 4a4cdd6b07a36bbf58643e96c9a16d3851ca5bc5 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 28 Jul 2023 23:19:49 +0530 Subject: [PATCH] [Feat] Support SDXL Kohya-style LoRA (#4287) * sdxl lora changes. * better name replacement. * better replacement. * debugging * debugging * debugging * debugging * debugging * remove print. * print state dict keys. * print * distingisuih better * debuggable. * fxi: tyests * fix: arg from training script. * access from class. * run style * debug * save intermediate * some simplifications for SDXL LoRA * styling * unet config is not needed in diffusers format. * fix: dynamic SGM block mapping for SDXL kohya loras (#4322) * Use lora compatible layers for linear proj_in/proj_out (#4323) * improve condition for using the sgm_diffusers mapping * informative comment. * load compatible keys and embedding layer maaping. * Get SDXL 1.0 example lora to load * simplify * specif ranks and hidden sizes. * better handling of k rank and hidden * debug * debug * debug * debug * debug * fix: alpha keys * add check for handling LoRAAttnAddedKVProcessor * sanity comment * modifications for text encoder SDXL * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * denugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * up * up * up * up * up * up * unneeded comments. * unneeded comments. * kwargs for the other attention processors. * kwargs for the other attention processors. * debugging * debugging * debugging * debugging * improve * debugging * debugging * more print * Fix alphas * debugging * debugging * debugging * debugging * debugging * debugging * clean up * clean up. * debugging * fix: text --------- Co-authored-by: Patrick von Platen Co-authored-by: Batuhan Taskaya --- docs/source/en/training/lora.md | 50 +- examples/dreambooth/train_dreambooth_lora.py | 6 +- .../dreambooth/train_dreambooth_lora_sdxl.py | 8 +- src/diffusers/loaders.py | 527 +++++++++++++----- src/diffusers/models/attention_processor.py | 75 ++- src/diffusers/models/lora.py | 11 +- src/diffusers/models/resnet.py | 15 +- src/diffusers/models/transformer_2d.py | 6 +- .../pipeline_stable_diffusion_xl.py | 19 +- tests/models/test_lora_layers.py | 3 +- 10 files changed, 550 insertions(+), 170 deletions(-) diff --git a/docs/source/en/training/lora.md b/docs/source/en/training/lora.md index 670a946581..fd88d74854 100644 --- a/docs/source/en/training/lora.md +++ b/docs/source/en/training/lora.md @@ -354,4 +354,52 @@ directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so: lora_model_id = "sayakpaul/civitai-light-shadow-lora" lora_filename = "light_and_shadow.safetensors" pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename) -``` \ No newline at end of file +``` + +### Supporting Stable Diffusion XL LoRAs trained using the Kohya-trainer + +With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL). 
+
+Here are some example checkpoints we tried out:
+
+* SDXL 0.9:
+  * https://civitai.com/models/22279?modelVersionId=118556
+  * https://civitai.com/models/104515/sdxlor30costumesrevue-starlight-saijoclaudine-lora
+  * https://civitai.com/models/108448/daiton-sdxl-test
+  * https://filebin.net/2ntfqqnapiu9q3zx/pixelbuildings128-v1.safetensors
+* SDXL 1.0:
+  * https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors
+
+Here is an example of how to perform inference with these checkpoints in `diffusers`:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+base_model_id = "stabilityai/stable-diffusion-xl-base-0.9"
+pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
+pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors")
+
+prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint "
+negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions"
+generator = torch.manual_seed(2947883060)
+num_inference_steps = 30
+guidance_scale = 7
+
+image = pipeline(
+    prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps,
+    generator=generator, guidance_scale=guidance_scale
+).images[0]
+image.save("Kamepan.png")
+```
+
+`Kamepan.safetensors` comes from https://civitai.com/models/22279?modelVersionId=118556 .
+
+Note that the inference UX is identical to the one presented in the sections above.
+
+Thanks to [@isidentical](https://github.com/isidentical) for helping us integrate this feature.
+
+### Known limitations specific to the Kohya-style LoRAs
+
+* SDXL LoRAs that contain weights for both text encoders currently lead to unexpected results. We're actively investigating the issue.
+* When the generated images don't look similar to those from other UIs, such as ComfyUI, it can be because of multiple reasons, as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
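+
+If a Kohya-style LoRA doesn't behave as expected because of these limitations, you can detach it and fall back to the base model with [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`]. The snippet below is a minimal sketch that reuses the `pipeline` and `prompt` from the example above:
+
+```python
+# Remove the LoRA layers that were loaded into the UNet and the text encoder.
+pipeline.unload_lora_weights()
+
+# The pipeline now produces plain SDXL base-model outputs again.
+image = pipeline(prompt=prompt, num_inference_steps=30).images[0]
+image.save("no_lora.png")
+```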
\ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 8bd9f5d722..f6c9990a37 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -925,10 +925,10 @@ def main(args): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_ + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_ ) accelerator.register_save_state_pre_hook(save_model_hook) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 443f2d25a7..0383ab4b99 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -825,13 +825,13 @@ def main(args): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_ + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ ) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_ + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ ) accelerator.register_save_state_pre_hook(save_model_hook) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e5b5062591..6a6e03117e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import re import warnings from collections import defaultdict from contextlib import nullcontext @@ -56,7 +57,6 @@ UNET_NAME = "unet" LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" -TOTAL_EXAMPLE_KEYS = 5 TEXT_INVERSION_NAME = "learned_embeds.bin" TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" @@ -257,7 +257,7 @@ class UNet2DConditionLoadersMixin: use_safetensors = kwargs.pop("use_safetensors", None) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning - network_alpha = kwargs.pop("network_alpha", None) + network_alphas = kwargs.pop("network_alphas", None) if use_safetensors and not is_safetensors_available(): raise ValueError( @@ -322,7 +322,7 @@ class UNet2DConditionLoadersMixin: attn_processors = {} non_attn_lora_layers = [] - is_lora = all("lora" in k for k in state_dict.keys()) + is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) if is_lora: @@ -339,10 +339,25 @@ class UNet2DConditionLoadersMixin: state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} lora_grouped_dict = defaultdict(dict) - for key, value in state_dict.items(): + mapped_network_alphas = {} + + all_keys = list(state_dict.keys()) + for key in all_keys: + value = state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value + # Create another `mapped_network_alphas` dictionary so that we can properly map them. + if network_alphas is not None: + for k in network_alphas: + if k.replace(".alpha", "") in key: + mapped_network_alphas.update({attn_processor_key: network_alphas[k]}) + + if len(state_dict) > 0: + raise ValueError( + f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" + ) + for key, value_dict in lora_grouped_dict.items(): attn_processor = self for sub_key in key.split("."): @@ -352,13 +367,27 @@ class UNet2DConditionLoadersMixin: # or add_{k,v,q,out_proj}_proj_lora layers. if "lora.down.weight" in value_dict: rank = value_dict["lora.down.weight"].shape[0] - hidden_size = value_dict["lora.up.weight"].shape[0] if isinstance(attn_processor, LoRACompatibleConv): - lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha) + in_features = attn_processor.in_channels + out_features = attn_processor.out_channels + kernel_size = attn_processor.kernel_size + + lora = LoRAConv2dLayer( + in_features=in_features, + out_features=out_features, + rank=rank, + kernel_size=kernel_size, + stride=attn_processor.stride, + padding=attn_processor.padding, + network_alpha=mapped_network_alphas.get(key), + ) elif isinstance(attn_processor, LoRACompatibleLinear): lora = LoRALinearLayer( - attn_processor.in_features, attn_processor.out_features, rank, network_alpha + attn_processor.in_features, + attn_processor.out_features, + rank, + mapped_network_alphas.get(key), ) else: raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") @@ -366,32 +395,64 @@ class UNet2DConditionLoadersMixin: value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} lora.load_state_dict(value_dict) non_attn_lora_layers.append((attn_processor, lora)) - continue - - rank = value_dict["to_k_lora.down.weight"].shape[0] - hidden_size = value_dict["to_k_lora.up.weight"].shape[0] - - if isinstance( - attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0) - ): - cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1] - attn_processor_class = LoRAAttnAddedKVProcessor else: - cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] - if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)): - attn_processor_class = LoRAXFormersAttnProcessor + # To 
handle SDXL. + rank_mapping = {} + hidden_size_mapping = {} + for projection_id in ["to_k", "to_q", "to_v", "to_out"]: + rank = value_dict[f"{projection_id}_lora.down.weight"].shape[0] + hidden_size = value_dict[f"{projection_id}_lora.up.weight"].shape[0] + + rank_mapping.update({f"{projection_id}_lora.down.weight": rank}) + hidden_size_mapping.update({f"{projection_id}_lora.up.weight": hidden_size}) + + if isinstance( + attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0) + ): + cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1] + attn_processor_class = LoRAAttnAddedKVProcessor else: - attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)): + attn_processor_class = LoRAXFormersAttnProcessor + else: + attn_processor_class = ( + LoRAAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else LoRAAttnProcessor + ) + + if attn_processor_class is not LoRAAttnAddedKVProcessor: + attn_processors[key] = attn_processor_class( + rank=rank_mapping.get("to_k_lora.down.weight"), + hidden_size=hidden_size_mapping.get("to_k_lora.up.weight"), + cross_attention_dim=cross_attention_dim, + network_alpha=mapped_network_alphas.get(key), + q_rank=rank_mapping.get("to_q_lora.down.weight"), + q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight"), + v_rank=rank_mapping.get("to_v_lora.down.weight"), + v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"), + out_rank=rank_mapping.get("to_out_lora.down.weight"), + out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"), + # rank=rank_mapping.get("to_k_lora.down.weight", None), + # hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None), + # q_rank=rank_mapping.get("to_q_lora.down.weight", None), + # q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None), + # v_rank=rank_mapping.get("to_v_lora.down.weight", None), + # v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None), + # out_rank=rank_mapping.get("to_out_lora.down.weight", None), + # out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None), + ) + else: + attn_processors[key] = attn_processor_class( + rank=rank_mapping.get("to_k_lora.down.weight", None), + hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None), + cross_attention_dim=cross_attention_dim, + network_alpha=mapped_network_alphas.get(key), ) - attn_processors[key] = attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=rank, - network_alpha=network_alpha, - ) - attn_processors[key].load_state_dict(value_dict) + attn_processors[key].load_state_dict(value_dict) + elif is_custom_diffusion: custom_diffusion_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): @@ -434,8 +495,10 @@ class UNet2DConditionLoadersMixin: # set ff layers for target_module, lora_layer in non_attn_lora_layers: - if hasattr(target_module, "set_lora_layer"): - target_module.set_lora_layer(lora_layer) + target_module.set_lora_layer(lora_layer) + # It should raise an error if we don't have a set lora here + # if hasattr(target_module, "set_lora_layer"): + # target_module.set_lora_layer(lora_layer) def save_attn_procs( self, @@ -873,11 +936,11 @@ class LoraLoaderMixin: kwargs (`dict`, *optional*): See [`~loaders.LoraLoaderMixin.lora_state_dict`]. 
""" - state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) self.load_lora_into_text_encoder( state_dict, - network_alpha=network_alpha, + network_alphas=network_alphas, text_encoder=self.text_encoder, lora_scale=self.lora_scale, ) @@ -889,7 +952,7 @@ class LoraLoaderMixin: **kwargs, ): r""" - Return state dict for lora weights + Return state dict for lora weights and the network alphas. @@ -950,6 +1013,7 @@ class LoraLoaderMixin: revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) if use_safetensors and not is_safetensors_available(): @@ -1011,53 +1075,158 @@ class LoraLoaderMixin: else: state_dict = pretrained_model_name_or_path_or_dict - # Convert kohya-ss Style LoRA attn procs to diffusers attn procs - network_alpha = None - if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()): - state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict) + network_alphas = None + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. + if unet_config is not None: + # use unet config to remap block numbers + state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict) - return state_dict, network_alpha + return state_dict, network_alphas @classmethod - def load_lora_into_unet(cls, state_dict, network_alpha, unet): + def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5): + is_all_unet = all(k.startswith("lora_unet") for k in state_dict) + new_state_dict = {} + inner_block_map = ["resnets", "attentions", "upsamplers"] + + # Retrieves # of down, mid and up blocks + input_block_ids, middle_block_ids, output_block_ids = set(), set(), set() + for layer in state_dict: + if "text" not in layer: + layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) + if "input_blocks" in layer: + input_block_ids.add(layer_id) + elif "middle_block" in layer: + middle_block_ids.add(layer_id) + elif "output_blocks" in layer: + output_block_ids.add(layer_id) + else: + raise ValueError("Checkpoint not supported") + + input_blocks = { + layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] + for layer_id in input_block_ids + } + middle_blocks = { + layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] + for layer_id in middle_block_ids + } + output_blocks = { + layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key] + for layer_id in output_block_ids + } + + # Rename keys accordingly + for i in input_block_ids: + block_id = (i - 1) // (unet_config.layers_per_block + 1) + layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1) + + for key in input_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers" + 
inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in middle_block_ids: + key_part = None + if i == 0: + key_part = [inner_block_map[0], "0"] + elif i == 1: + key_part = [inner_block_map[1], "0"] + elif i == 2: + key_part = [inner_block_map[0], "1"] + else: + raise ValueError(f"Invalid middle block id {i}.") + + for key in middle_blocks[i]: + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in output_block_ids: + block_id = i // (unet_config.layers_per_block + 1) + layer_in_block_id = i % (unet_config.layers_per_block + 1) + + for key in output_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] + inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + if is_all_unet and len(state_dict) > 0: + raise ValueError("At this point all state dict entries have to be converted.") + else: + # Remaining is the text encoder state dict. + for k, v in state_dict.items(): + new_state_dict.update({k: v}) + + return new_state_dict + + @classmethod + def load_lora_into_unet(cls, state_dict, network_alphas, unet): """ - This will load the LoRA layers specified in `state_dict` into `unet` + This will load the LoRA layers specified in `state_dict` into `unet`. Parameters: state_dict (`dict`): A standard state dict containing the lora layer parameters. The keys can either be indexed directly into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. - network_alpha (`float`): + network_alphas (`Dict[str, float]`): See `LoRALinearLayer` for more details. unet (`UNet2DConditionModel`): The UNet model to load the LoRA layers into. """ - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) + if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): # Load the layers corresponding to UNet. - unet_keys = [k for k in keys if k.startswith(cls.unet_name)] logger.info(f"Loading {cls.unet_name}.") - unet_lora_state_dict = { - k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys - } - unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha) - # Otherwise, we're dealing with the old format. This means the `state_dict` should only - # contain the module names of the `unet` as its keys WITHOUT any prefix. 
- elif not all( - key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys() - ): - unet.load_attn_procs(state_dict, network_alpha=network_alpha) + unet_keys = [k for k in keys if k.startswith(cls.unet_name)] + state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)] + network_alphas = { + k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + else: + # Otherwise, we're dealing with the old format. This means the `state_dict` should only + # contain the module names of the `unet` as its keys WITHOUT any prefix. warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." warnings.warn(warn_message) + # load loras into unet + unet.load_attn_procs(state_dict, network_alphas=network_alphas) + @classmethod - def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0): + def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1065,7 +1234,7 @@ class LoraLoaderMixin: state_dict (`dict`): A standard state dict containing the lora layer parameters. The key should be prefixed with an additional `text_encoder` to distinguish between unet lora layers. - network_alpha (`float`): + network_alphas (`Dict[str, float]`): See `LoRALinearLayer` for more details. text_encoder (`CLIPTextModel`): The text encoder model to load the LoRA layers into. @@ -1134,14 +1303,19 @@ class LoraLoaderMixin: ].shape[1] patch_mlp = any(".mlp." 
in key for key in text_encoder_lora_state_dict.keys()) - cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp) + cls._modify_text_encoder( + text_encoder, + lora_scale, + network_alphas, + rank=rank, + patch_mlp=patch_mlp, + ) # set correct dtype & device text_encoder_lora_state_dict = { k: v.to(device=text_encoder.device, dtype=text_encoder.dtype) for k, v in text_encoder_lora_state_dict.items() } - load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) if len(load_state_dict_results.unexpected_keys) != 0: raise ValueError( @@ -1176,7 +1350,7 @@ class LoraLoaderMixin: cls, text_encoder, lora_scale=1, - network_alpha=None, + network_alphas=None, rank=4, dtype=None, patch_mlp=False, @@ -1189,37 +1363,46 @@ class LoraLoaderMixin: cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) lora_parameters = [] + network_alphas = {} if network_alphas is None else network_alphas + + for name, attn_module in text_encoder_attn_modules(text_encoder): + query_alpha = network_alphas.get(name + ".k.proj.alpha") + key_alpha = network_alphas.get(name + ".q.proj.alpha") + value_alpha = network_alphas.get(name + ".v.proj.alpha") + proj_alpha = network_alphas.get(name + ".out.proj.alpha") - for _, attn_module in text_encoder_attn_modules(text_encoder): attn_module.q_proj = PatchedLoraProjection( - attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters()) attn_module.k_proj = PatchedLoraProjection( - attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters()) attn_module.v_proj = PatchedLoraProjection( - attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters()) attn_module.out_proj = PatchedLoraProjection( - attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters()) if patch_mlp: - for _, mlp_module in text_encoder_mlp_modules(text_encoder): + for name, mlp_module in text_encoder_mlp_modules(text_encoder): + fc1_alpha = network_alphas.get(name + ".fc1.alpha") + fc2_alpha = network_alphas.get(name + ".fc2.alpha") + mlp_module.fc1 = PatchedLoraProjection( - mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype + mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters()) mlp_module.fc2 = PatchedLoraProjection( - mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype + mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters()) @@ -1326,77 +1509,163 @@ class LoraLoaderMixin: def _convert_kohya_lora_to_diffusers(cls, state_dict): unet_state_dict = {} te_state_dict = {} - network_alpha = None - unloaded_keys = [] + te2_state_dict = {} + network_alphas = {} - for key, value in state_dict.items(): - if "hada" in key or "skip" in key: - 
unloaded_keys.append(key) - elif "lora_down" in key: - lora_name = key.split(".")[0] - lora_name_up = lora_name + ".lora_up.weight" - lora_name_alpha = lora_name + ".alpha" - if lora_name_alpha in state_dict: - alpha = state_dict[lora_name_alpha].item() - if network_alpha is None: - network_alpha = alpha - elif network_alpha != alpha: - raise ValueError("Network alpha is not consistent") + # every down weight has a corresponding up weight and potentially an alpha weight + lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] + for key in lora_keys: + lora_name = key.split(".")[0] + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" - if lora_name.startswith("lora_unet_"): - diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + # if lora_name_alpha in state_dict: + # alpha = state_dict.pop(lora_name_alpha).item() + # network_alphas.update({lora_name_alpha: alpha}) + + if lora_name.startswith("lora_unet_"): + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + + if "input.blocks" in diffusers_name: + diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") + else: diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + + if "middle.block" in diffusers_name: + diffusers_name = diffusers_name.replace("middle.block", "mid_block") + else: diffusers_name = diffusers_name.replace("mid.block", "mid_block") + if "output.blocks" in diffusers_name: + diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") + else: diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") - diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") - diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") - diffusers_name = diffusers_name.replace("proj.in", "proj_in") - diffusers_name = diffusers_name.replace("proj.out", "proj_out") - if "transformer_blocks" in diffusers_name: - if "attn1" in diffusers_name or "attn2" in diffusers_name: - diffusers_name = diffusers_name.replace("attn1", "attn1.processor") - diffusers_name = diffusers_name.replace("attn2", "attn2.processor") - unet_state_dict[diffusers_name] = value - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - elif "ff" in diffusers_name: - unet_state_dict[diffusers_name] = value - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - elif any(key in diffusers_name for key in ("proj_in", "proj_out")): - unet_state_dict[diffusers_name] = value - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - elif lora_name.startswith("lora_te_"): - diffusers_name = key.replace("lora_te_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("text.model", "text_model") - diffusers_name = diffusers_name.replace("self.attn", "self_attn") - diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - if "self_attn" in diffusers_name: - te_state_dict[diffusers_name] = value - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - elif "mlp" in 
diffusers_name: - # Be aware that this is the new diffusers convention and the rest of the code might - # not utilize it yet. - diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") - te_state_dict[diffusers_name] = value - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") + diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") - logger.info("Kohya-style checkpoint detected.") - if len(unloaded_keys) > 0: - example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS]) - logger.warning( - f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for." + # SDXL specificity. + if "emb" in diffusers_name: + pattern = r"\.\d+(?=\D*$)" + diffusers_name = re.sub(pattern, "", diffusers_name, count=1) + if ".in." in diffusers_name: + diffusers_name = diffusers_name.replace("in.layers.2", "conv1") + if ".out." in diffusers_name: + diffusers_name = diffusers_name.replace("out.layers.3", "conv2") + if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: + diffusers_name = diffusers_name.replace("op", "conv") + if "skip" in diffusers_name: + diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") + + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "ff" in diffusers_name: + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + else: + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + elif lora_name.startswith("lora_te_"): + diffusers_name = key.replace("lora_te_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it 
yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # (sayakpaul): Duplicate code. Needs to be cleaned. + elif lora_name.startswith("lora_te1_"): + diffusers_name = key.replace("lora_te1_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # (sayakpaul): Duplicate code. Needs to be cleaned. + elif lora_name.startswith("lora_te2_"): + diffusers_name = key.replace("lora_te2_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # Rename the alphas so that they can be mapped appropriately. + if lora_name_alpha in state_dict: + alpha = state_dict.pop(lora_name_alpha).item() + if lora_name_alpha.startswith("lora_unet_"): + prefix = "unet." + elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): + prefix = "text_encoder." + else: + prefix = "text_encoder_2." 
+ new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" + network_alphas.update({new_name: alpha}) + + if len(state_dict) > 0: + raise ValueError( + f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}" ) - unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} - te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} + logger.info("Kohya-style checkpoint detected.") + unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = { + f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items() + } + te2_state_dict = ( + {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()} + if len(te2_state_dict) > 0 + else None + ) + if te2_state_dict is not None: + te_state_dict.update(te2_state_dict) + new_state_dict = {**unet_state_dict, **te_state_dict} - return new_state_dict, network_alpha + return new_state_dict, network_alphas def unload_lora_weights(self): """ diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index de4adec042..43497c2284 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -521,17 +521,32 @@ class LoRAAttnProcessor(nn.Module): Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. """ - def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + q_rank = kwargs.pop("q_rank", None) + q_hidden_size = kwargs.pop("q_hidden_size", None) + q_rank = q_rank if q_rank is not None else rank + q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size + + v_rank = kwargs.pop("v_rank", None) + v_hidden_size = kwargs.pop("v_hidden_size", None) + v_rank = v_rank if v_rank is not None else rank + v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size + + out_rank = kwargs.pop("out_rank", None) + out_hidden_size = kwargs.pop("out_hidden_size", None) + out_rank = out_rank if out_rank is not None else rank + out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size + + self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) + self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None @@ -1144,7 +1159,13 @@ class LoRAXFormersAttnProcessor(nn.Module): """ def __init__( - self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None + self, + hidden_size, + 
cross_attention_dim, + rank=4, + attention_op: Optional[Callable] = None, + network_alpha=None, + **kwargs, ): super().__init__() @@ -1153,10 +1174,25 @@ class LoRAXFormersAttnProcessor(nn.Module): self.rank = rank self.attention_op = attention_op - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + q_rank = kwargs.pop("q_rank", None) + q_hidden_size = kwargs.pop("q_hidden_size", None) + q_rank = q_rank if q_rank is not None else rank + q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size + + v_rank = kwargs.pop("v_rank", None) + v_hidden_size = kwargs.pop("v_hidden_size", None) + v_rank = v_rank if v_rank is not None else rank + v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size + + out_rank = kwargs.pop("out_rank", None) + out_hidden_size = kwargs.pop("out_hidden_size", None) + out_rank = out_rank if out_rank is not None else rank + out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size + + self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) + self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None @@ -1231,7 +1267,7 @@ class LoRAAttnProcessor2_0(nn.Module): Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. 
""" - def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -1240,10 +1276,25 @@ class LoRAAttnProcessor2_0(nn.Module): self.cross_attention_dim = cross_attention_dim self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + q_rank = kwargs.pop("q_rank", None) + q_hidden_size = kwargs.pop("q_hidden_size", None) + q_rank = q_rank if q_rank is not None else rank + q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size + + v_rank = kwargs.pop("v_rank", None) + v_hidden_size = kwargs.pop("v_hidden_size", None) + v_rank = v_rank if v_rank is not None else rank + v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size + + out_rank = kwargs.pop("out_rank", None) + out_hidden_size = kwargs.pop("out_hidden_size", None) + out_rank = out_rank if out_rank is not None else rank + out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size + + self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) + self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index 7bc573bf72..171f1323cf 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -49,14 +49,19 @@ class LoRALinearLayer(nn.Module): class LoRAConv2dLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4, network_alpha=None): + def __init__( + self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None + ): super().__init__() if rank > min(in_features, out_features): raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") - self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False) - self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False) + self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + # according to the official kohya_ss trainer kernel_size are always fixed for the up layer + # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 + self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning self.network_alpha = network_alpha diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 24c3b07e7c..72aa17ed2c 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -23,6 +23,7 @@ import torch.nn.functional as F from .activations import get_activation from .attention import AdaGroupNorm from .attention_processor import SpatialNorm +from .lora import LoRACompatibleConv, LoRACompatibleLinear class Upsample1D(nn.Module): @@ -126,7 +127,7 @@ class Upsample2D(nn.Module): if use_conv_transpose: conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) elif use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": @@ -196,7 +197,7 @@ class Downsample2D(nn.Module): self.name = name if use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels conv = nn.AvgPool2d(kernel_size=stride, stride=stride) @@ -534,13 +535,13 @@ class ResnetBlock2D(nn.Module): else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": - self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": - self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) + self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels) elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": self.time_emb_proj = None else: @@ -557,7 +558,7 @@ class ResnetBlock2D(nn.Module): self.dropout = torch.nn.Dropout(dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels - self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) self.nonlinearity = get_activation(non_linearity) @@ -583,7 +584,7 @@ class ResnetBlock2D(nn.Module): self.conv_shortcut = None if self.use_in_shortcut: - self.conv_shortcut = torch.nn.Conv2d( + self.conv_shortcut = LoRACompatibleConv( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index bbd93430da..998535c58a 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -23,7 +23,7 @@ from ..models.embeddings import ImagePositionalEmbeddings from ..utils import BaseOutput, deprecate from .attention import BasicTransformerBlock from .embeddings import PatchEmbed -from .lora import LoRACompatibleConv +from .lora import LoRACompatibleConv, LoRACompatibleLinear from .modeling_utils import ModelMixin @@ -137,7 +137,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): 
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) else: self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: @@ -193,7 +193,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): if self.is_input_continuous: # TODO: should use out_channels for continuous projections if use_linear_projection: - self.proj_out = nn.Linear(inner_dim, in_channels) + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) else: self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index a48600f35d..01e78def6b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -88,11 +88,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] Args: vae ([`AutoencoderKL`]): @@ -866,14 +866,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad # Overrride to properly handle the loading and unloading of the additional text encoder. def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): - state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." 
in k} if len(text_encoder_state_dict) > 0: self.load_lora_into_text_encoder( text_encoder_state_dict, - network_alpha=network_alpha, + network_alphas=network_alphas, text_encoder=self.text_encoder, prefix="text_encoder", lora_scale=self.lora_scale, @@ -883,7 +890,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad if len(text_encoder_2_state_dict) > 0: self.load_lora_into_text_encoder( text_encoder_2_state_dict, - network_alpha=network_alpha, + network_alphas=network_alphas, text_encoder=self.text_encoder_2, prefix="text_encoder_2", lora_scale=self.lora_scale, diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 0d9031c9aa..000748312f 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -737,8 +737,7 @@ class LoraIntegrationTests(unittest.TestCase): ).images images = images[0, -3:, -3:, -1].flatten() - - expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392]) self.assertTrue(np.allclose(images, expected, atol=1e-4))