1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Feat] Support SDXL Kohya-style LoRA (#4287)

* sdxl lora changes.

* better name replacement.

* better replacement.

* debugging

* debugging

* debugging

* debugging

* debugging

* remove print.

* print state dict keys.

* print

* distingisuih better

* debuggable.

* fxi: tyests

* fix: arg from training script.

* access from class.

* run style

* debug

* save intermediate

* some simplifications for SDXL LoRA

* styling

* unet config is not needed in diffusers format.

* fix: dynamic SGM block mapping for SDXL kohya loras (#4322)

* Use lora compatible layers for linear proj_in/proj_out (#4323)

* improve condition for using the sgm_diffusers mapping

* informative comment.

* load compatible keys and embedding layer maaping.

* Get SDXL 1.0 example lora to load

* simplify

* specif ranks and hidden sizes.

* better handling of k rank and hidden

* debug

* debug

* debug

* debug

* debug

* fix: alpha keys

* add check for handling LoRAAttnAddedKVProcessor

* sanity comment

* modifications for text encoder SDXL

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* denugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* up

* up

* up

* up

* up

* up

* unneeded comments.

* unneeded comments.

* kwargs for the other attention processors.

* kwargs for the other attention processors.

* debugging

* debugging

* debugging

* debugging

* improve

* debugging

* debugging

* more print

* Fix alphas

* debugging

* debugging

* debugging

* debugging

* debugging

* debugging

* clean up

* clean up.

* debugging

* fix: text

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Batuhan Taskaya <batuhan@python.org>
This commit is contained in:
Sayak Paul
2023-07-28 23:19:49 +05:30
committed by GitHub
parent b7b6d6138d
commit 4a4cdd6b07
10 changed files with 550 additions and 170 deletions

View File

@@ -354,4 +354,52 @@ directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
lora_model_id = "sayakpaul/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
```
```
### Supporting Stable Diffusion XL LoRAs trained using the Kohya-trainer
With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL).
Here are some example checkpoints we tried out:
* SDXL 0.9:
* https://civitai.com/models/22279?modelVersionId=118556
* https://civitai.com/models/104515/sdxlor30costumesrevue-starlight-saijoclaudine-lora
* https://civitai.com/models/108448/daiton-sdxl-test
* https://filebin.net/2ntfqqnapiu9q3zx/pixelbuildings128-v1.safetensors
* SDXL 1.0:
* https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors
Here is an example of how to perform inference with these checkpoints in `diffusers`:
```python
from diffusers import DiffusionPipeline
import torch
base_model_id = "stabilityai/stable-diffusion-xl-base-0.9"
pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors")
prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint <lora:kame_sdxl_v2:1>"
negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions"
generator = torch.manual_seed(2947883060)
num_inference_steps = 30
guidance_scale = 7
image = pipeline(
prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps,
generator=generator, guidance_scale=guidance_scale
).images[0]
image.save("Kamepan.png")
```
`Kamepan.safetensors` comes from https://civitai.com/models/22279?modelVersionId=118556 .
If you notice carefully, the inference UX is exactly identical to what we presented in the sections above.
Thanks to [@isidentical](https://github.com/isidentical) for helping us on integrating this feature.
### Known limitations specific to the Kohya-styled LoRAs
* SDXL LoRAs that have both the text encoders are currently leading to weird results. We're actively investigating the issue.
* When images don't looks similar to other UIs such ComfyUI, it can be beacause of multiple reasons as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).

View File

@@ -925,10 +925,10 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)
accelerator.register_save_state_pre_hook(save_model_hook)

View File

@@ -825,13 +825,13 @@ def main(args):
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)
accelerator.register_save_state_pre_hook(save_model_hook)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import warnings
from collections import defaultdict
from contextlib import nullcontext
@@ -56,7 +57,6 @@ UNET_NAME = "unet"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
TOTAL_EXAMPLE_KEYS = 5
TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
@@ -257,7 +257,7 @@ class UNet2DConditionLoadersMixin:
use_safetensors = kwargs.pop("use_safetensors", None)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alpha = kwargs.pop("network_alpha", None)
network_alphas = kwargs.pop("network_alphas", None)
if use_safetensors and not is_safetensors_available():
raise ValueError(
@@ -322,7 +322,7 @@ class UNet2DConditionLoadersMixin:
attn_processors = {}
non_attn_lora_layers = []
is_lora = all("lora" in k for k in state_dict.keys())
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
if is_lora:
@@ -339,10 +339,25 @@ class UNet2DConditionLoadersMixin:
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
lora_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
mapped_network_alphas = {}
all_keys = list(state_dict.keys())
for key in all_keys:
value = state_dict.pop(key)
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value
# Create another `mapped_network_alphas` dictionary so that we can properly map them.
if network_alphas is not None:
for k in network_alphas:
if k.replace(".alpha", "") in key:
mapped_network_alphas.update({attn_processor_key: network_alphas[k]})
if len(state_dict) > 0:
raise ValueError(
f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
)
for key, value_dict in lora_grouped_dict.items():
attn_processor = self
for sub_key in key.split("."):
@@ -352,13 +367,27 @@ class UNet2DConditionLoadersMixin:
# or add_{k,v,q,out_proj}_proj_lora layers.
if "lora.down.weight" in value_dict:
rank = value_dict["lora.down.weight"].shape[0]
hidden_size = value_dict["lora.up.weight"].shape[0]
if isinstance(attn_processor, LoRACompatibleConv):
lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha)
in_features = attn_processor.in_channels
out_features = attn_processor.out_channels
kernel_size = attn_processor.kernel_size
lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
network_alpha=mapped_network_alphas.get(key),
)
elif isinstance(attn_processor, LoRACompatibleLinear):
lora = LoRALinearLayer(
attn_processor.in_features, attn_processor.out_features, rank, network_alpha
attn_processor.in_features,
attn_processor.out_features,
rank,
mapped_network_alphas.get(key),
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
@@ -366,32 +395,64 @@ class UNet2DConditionLoadersMixin:
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
lora.load_state_dict(value_dict)
non_attn_lora_layers.append((attn_processor, lora))
continue
rank = value_dict["to_k_lora.down.weight"].shape[0]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
if isinstance(
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
):
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
attn_processor_class = LoRAAttnAddedKVProcessor
else:
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
attn_processor_class = LoRAXFormersAttnProcessor
# To handle SDXL.
rank_mapping = {}
hidden_size_mapping = {}
for projection_id in ["to_k", "to_q", "to_v", "to_out"]:
rank = value_dict[f"{projection_id}_lora.down.weight"].shape[0]
hidden_size = value_dict[f"{projection_id}_lora.up.weight"].shape[0]
rank_mapping.update({f"{projection_id}_lora.down.weight": rank})
hidden_size_mapping.update({f"{projection_id}_lora.up.weight": hidden_size})
if isinstance(
attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)
):
cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1]
attn_processor_class = LoRAAttnAddedKVProcessor
else:
attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)):
attn_processor_class = LoRAXFormersAttnProcessor
else:
attn_processor_class = (
LoRAAttnProcessor2_0
if hasattr(F, "scaled_dot_product_attention")
else LoRAAttnProcessor
)
if attn_processor_class is not LoRAAttnAddedKVProcessor:
attn_processors[key] = attn_processor_class(
rank=rank_mapping.get("to_k_lora.down.weight"),
hidden_size=hidden_size_mapping.get("to_k_lora.up.weight"),
cross_attention_dim=cross_attention_dim,
network_alpha=mapped_network_alphas.get(key),
q_rank=rank_mapping.get("to_q_lora.down.weight"),
q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight"),
v_rank=rank_mapping.get("to_v_lora.down.weight"),
v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"),
out_rank=rank_mapping.get("to_out_lora.down.weight"),
out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"),
# rank=rank_mapping.get("to_k_lora.down.weight", None),
# hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
# q_rank=rank_mapping.get("to_q_lora.down.weight", None),
# q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None),
# v_rank=rank_mapping.get("to_v_lora.down.weight", None),
# v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None),
# out_rank=rank_mapping.get("to_out_lora.down.weight", None),
# out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None),
)
else:
attn_processors[key] = attn_processor_class(
rank=rank_mapping.get("to_k_lora.down.weight", None),
hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None),
cross_attention_dim=cross_attention_dim,
network_alpha=mapped_network_alphas.get(key),
)
attn_processors[key] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=rank,
network_alpha=network_alpha,
)
attn_processors[key].load_state_dict(value_dict)
attn_processors[key].load_state_dict(value_dict)
elif is_custom_diffusion:
custom_diffusion_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
@@ -434,8 +495,10 @@ class UNet2DConditionLoadersMixin:
# set ff layers
for target_module, lora_layer in non_attn_lora_layers:
if hasattr(target_module, "set_lora_layer"):
target_module.set_lora_layer(lora_layer)
target_module.set_lora_layer(lora_layer)
# It should raise an error if we don't have a set lora here
# if hasattr(target_module, "set_lora_layer"):
# target_module.set_lora_layer(lora_layer)
def save_attn_procs(
self,
@@ -873,11 +936,11 @@ class LoraLoaderMixin:
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
self.load_lora_into_text_encoder(
state_dict,
network_alpha=network_alpha,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
)
@@ -889,7 +952,7 @@ class LoraLoaderMixin:
**kwargs,
):
r"""
Return state dict for lora weights
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
@@ -950,6 +1013,7 @@ class LoraLoaderMixin:
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None)
if use_safetensors and not is_safetensors_available():
@@ -1011,53 +1075,158 @@ class LoraLoaderMixin:
else:
state_dict = pretrained_model_name_or_path_or_dict
# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
network_alpha = None
if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict)
network_alphas = None
if all(
(
k.startswith("lora_te_")
or k.startswith("lora_unet_")
or k.startswith("lora_te1_")
or k.startswith("lora_te2_")
)
for k in state_dict.keys()
):
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
return state_dict, network_alpha
return state_dict, network_alphas
@classmethod
def load_lora_into_unet(cls, state_dict, network_alpha, unet):
def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
new_state_dict = {}
inner_block_map = ["resnets", "attentions", "upsamplers"]
# Retrieves # of down, mid and up blocks
input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
for layer in state_dict:
if "text" not in layer:
layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
if "input_blocks" in layer:
input_block_ids.add(layer_id)
elif "middle_block" in layer:
middle_block_ids.add(layer_id)
elif "output_blocks" in layer:
output_block_ids.add(layer_id)
else:
raise ValueError("Checkpoint not supported")
input_blocks = {
layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
for layer_id in input_block_ids
}
middle_blocks = {
layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
for layer_id in middle_block_ids
}
output_blocks = {
layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
for layer_id in output_block_ids
}
# Rename keys accordingly
for i in input_block_ids:
block_id = (i - 1) // (unet_config.layers_per_block + 1)
layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
for key in input_blocks[i]:
inner_block_id = int(key.split(delimiter)[block_slice_pos])
inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1]
+ [str(block_id), inner_block_key, inner_layers_in_block]
+ key.split(delimiter)[block_slice_pos + 1 :]
)
new_state_dict[new_key] = state_dict.pop(key)
for i in middle_block_ids:
key_part = None
if i == 0:
key_part = [inner_block_map[0], "0"]
elif i == 1:
key_part = [inner_block_map[1], "0"]
elif i == 2:
key_part = [inner_block_map[0], "1"]
else:
raise ValueError(f"Invalid middle block id {i}.")
for key in middle_blocks[i]:
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
)
new_state_dict[new_key] = state_dict.pop(key)
for i in output_block_ids:
block_id = i // (unet_config.layers_per_block + 1)
layer_in_block_id = i % (unet_config.layers_per_block + 1)
for key in output_blocks[i]:
inner_block_id = int(key.split(delimiter)[block_slice_pos])
inner_block_key = inner_block_map[inner_block_id]
inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1]
+ [str(block_id), inner_block_key, inner_layers_in_block]
+ key.split(delimiter)[block_slice_pos + 1 :]
)
new_state_dict[new_key] = state_dict.pop(key)
if is_all_unet and len(state_dict) > 0:
raise ValueError("At this point all state dict entries have to be converted.")
else:
# Remaining is the text encoder state dict.
for k, v in state_dict.items():
new_state_dict.update({k: v})
return new_state_dict
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet):
"""
This will load the LoRA layers specified in `state_dict` into `unet`
This will load the LoRA layers specified in `state_dict` into `unet`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alpha (`float`):
network_alphas (`Dict[str, float]`):
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
"""
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
# Load the layers corresponding to UNet.
unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
logger.info(f"Loading {cls.unet_name}.")
unet_lora_state_dict = {
k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
elif not all(
key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys()
):
unet.load_attn_procs(state_dict, network_alpha=network_alpha)
unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
network_alphas = {
k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
else:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`."
warnings.warn(warn_message)
# load loras into unet
unet.load_attn_procs(state_dict, network_alphas=network_alphas)
@classmethod
def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0):
def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1065,7 +1234,7 @@ class LoraLoaderMixin:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alpha (`float`):
network_alphas (`Dict[str, float]`):
See `LoRALinearLayer` for more details.
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
@@ -1134,14 +1303,19 @@ class LoraLoaderMixin:
].shape[1]
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp)
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
)
# set correct dtype & device
text_encoder_lora_state_dict = {
k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
for k, v in text_encoder_lora_state_dict.items()
}
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
if len(load_state_dict_results.unexpected_keys) != 0:
raise ValueError(
@@ -1176,7 +1350,7 @@ class LoraLoaderMixin:
cls,
text_encoder,
lora_scale=1,
network_alpha=None,
network_alphas=None,
rank=4,
dtype=None,
patch_mlp=False,
@@ -1189,37 +1363,46 @@ class LoraLoaderMixin:
cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
lora_parameters = []
network_alphas = {} if network_alphas is None else network_alphas
for name, attn_module in text_encoder_attn_modules(text_encoder):
query_alpha = network_alphas.get(name + ".k.proj.alpha")
key_alpha = network_alphas.get(name + ".q.proj.alpha")
value_alpha = network_alphas.get(name + ".v.proj.alpha")
proj_alpha = network_alphas.get(name + ".out.proj.alpha")
for _, attn_module in text_encoder_attn_modules(text_encoder):
attn_module.q_proj = PatchedLoraProjection(
attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
attn_module.k_proj = PatchedLoraProjection(
attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
attn_module.v_proj = PatchedLoraProjection(
attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
attn_module.out_proj = PatchedLoraProjection(
attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype
attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
if patch_mlp:
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
for name, mlp_module in text_encoder_mlp_modules(text_encoder):
fc1_alpha = network_alphas.get(name + ".fc1.alpha")
fc2_alpha = network_alphas.get(name + ".fc2.alpha")
mlp_module.fc1 = PatchedLoraProjection(
mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype
mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
mlp_module.fc2 = PatchedLoraProjection(
mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype
mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype
)
lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
@@ -1326,77 +1509,163 @@ class LoraLoaderMixin:
def _convert_kohya_lora_to_diffusers(cls, state_dict):
unet_state_dict = {}
te_state_dict = {}
network_alpha = None
unloaded_keys = []
te2_state_dict = {}
network_alphas = {}
for key, value in state_dict.items():
if "hada" in key or "skip" in key:
unloaded_keys.append(key)
elif "lora_down" in key:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
if lora_name_alpha in state_dict:
alpha = state_dict[lora_name_alpha].item()
if network_alpha is None:
network_alpha = alpha
elif network_alpha != alpha:
raise ValueError("Network alpha is not consistent")
# every down weight has a corresponding up weight and potentially an alpha weight
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
for key in lora_keys:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
# if lora_name_alpha in state_dict:
# alpha = state_dict.pop(lora_name_alpha).item()
# network_alphas.update({lora_name_alpha: alpha})
if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
if "input.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
else:
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
if "middle.block" in diffusers_name:
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
else:
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
if "output.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
else:
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif "ff" in diffusers_name:
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
logger.info("Kohya-style checkpoint detected.")
if len(unloaded_keys) > 0:
example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS])
logger.warning(
f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for."
# SDXL specificity.
if "emb" in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
if ".out." in diffusers_name:
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
diffusers_name = diffusers_name.replace("op", "conv")
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "ff" in diffusers_name:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# (sayakpaul): Duplicate code. Needs to be cleaned.
elif lora_name.startswith("lora_te1_"):
diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# (sayakpaul): Duplicate code. Needs to be cleaned.
elif lora_name.startswith("lora_te2_"):
diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# Rename the alphas so that they can be mapped appropriately.
if lora_name_alpha in state_dict:
alpha = state_dict.pop(lora_name_alpha).item()
if lora_name_alpha.startswith("lora_unet_"):
prefix = "unet."
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
prefix = "text_encoder."
else:
prefix = "text_encoder_2."
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
network_alphas.update({new_name: alpha})
if len(state_dict) > 0:
raise ValueError(
f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}"
)
unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
logger.info("Kohya-style checkpoint detected.")
unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {
f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()
}
te2_state_dict = (
{f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
if len(te2_state_dict) > 0
else None
)
if te2_state_dict is not None:
te_state_dict.update(te2_state_dict)
new_state_dict = {**unet_state_dict, **te_state_dict}
return new_state_dict, network_alpha
return new_state_dict, network_alphas
def unload_lora_weights(self):
"""

View File

@@ -521,17 +521,32 @@ class LoRAAttnProcessor(nn.Module):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
"""
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -1144,7 +1159,13 @@ class LoRAXFormersAttnProcessor(nn.Module):
"""
def __init__(
self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
self,
hidden_size,
cross_attention_dim,
rank=4,
attention_op: Optional[Callable] = None,
network_alpha=None,
**kwargs,
):
super().__init__()
@@ -1153,10 +1174,25 @@ class LoRAXFormersAttnProcessor(nn.Module):
self.rank = rank
self.attention_op = attention_op
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -1231,7 +1267,7 @@ class LoRAAttnProcessor2_0(nn.Module):
Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
"""
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
@@ -1240,10 +1276,25 @@ class LoRAAttnProcessor2_0(nn.Module):
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
q_rank = kwargs.pop("q_rank", None)
q_hidden_size = kwargs.pop("q_hidden_size", None)
q_rank = q_rank if q_rank is not None else rank
q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
v_rank = kwargs.pop("v_rank", None)
v_hidden_size = kwargs.pop("v_hidden_size", None)
v_rank = v_rank if v_rank is not None else rank
v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
out_rank = kwargs.pop("out_rank", None)
out_hidden_size = kwargs.pop("out_hidden_size", None)
out_rank = out_rank if out_rank is not None else rank
out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states

View File

@@ -49,14 +49,19 @@ class LoRALinearLayer(nn.Module):
class LoRAConv2dLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
def __init__(
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False)
self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False)
self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
# according to the official kohya_ss trainer kernel_size are always fixed for the up layer
# # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha

View File

@@ -23,6 +23,7 @@ import torch.nn.functional as F
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
from .lora import LoRACompatibleConv, LoRACompatibleLinear
class Upsample1D(nn.Module):
@@ -126,7 +127,7 @@ class Upsample2D(nn.Module):
if use_conv_transpose:
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
@@ -196,7 +197,7 @@ class Downsample2D(nn.Module):
self.name = name
if use_conv:
conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
@@ -534,13 +535,13 @@ class ResnetBlock2D(nn.Module):
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
self.time_emb_proj = None
else:
@@ -557,7 +558,7 @@ class ResnetBlock2D(nn.Module):
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = get_activation(non_linearity)
@@ -583,7 +584,7 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
self.conv_shortcut = LoRACompatibleConv(
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
)

View File

@@ -23,7 +23,7 @@ from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .lora import LoRACompatibleConv
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
@@ -137,7 +137,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
else:
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
@@ -193,7 +193,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
else:
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:

View File

@@ -88,11 +88,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
- *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
as well as the following saving methods:
- *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`]
- *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
Args:
vae ([`AutoencoderKL`]):
@@ -866,14 +866,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
# Overrride to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
)
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alpha=network_alpha,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
@@ -883,7 +890,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
if len(text_encoder_2_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_2_state_dict,
network_alpha=network_alpha,
network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,

View File

@@ -737,8 +737,7 @@ class LoraIntegrationTests(unittest.TestCase):
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292])
expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392])
self.assertTrue(np.allclose(images, expected, atol=1e-4))