mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[LoRA] support kohya and xlabs loras for flux. (#9295)
* support kohya lora in flux. * format * support xlabs * diffusion_model prefix. * Apply suggestions from code review Co-authored-by: apolinário <joaopaulo.passos@gmail.com> * empty commit. Co-authored-by: Leommm-byte <leom20031@gmail.com> --------- Co-authored-by: apolinário <joaopaulo.passos@gmail.com> Co-authored-by: Leommm-byte <leom20031@gmail.com>
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import is_peft_version, logging
|
||||
|
||||
|
||||
@@ -326,3 +328,294 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
|
||||
prefix = "text_encoder_2."
|
||||
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
|
||||
return {new_name: alpha}
|
||||
|
||||
|
||||
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
|
||||
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
||||
# All credits go to `kohya-ss`.
|
||||
def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
|
||||
if sds_key + ".lora_down.weight" not in sds_sd:
|
||||
return
|
||||
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
|
||||
|
||||
# scale weight by alpha and dim
|
||||
rank = down_weight.shape[0]
|
||||
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
|
||||
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
|
||||
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
|
||||
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
|
||||
|
||||
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
|
||||
if sds_key + ".lora_down.weight" not in sds_sd:
|
||||
return
|
||||
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
|
||||
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
|
||||
sd_lora_rank = down_weight.shape[0]
|
||||
|
||||
# scale weight by alpha and dim
|
||||
alpha = sds_sd.pop(sds_key + ".alpha")
|
||||
scale = alpha / sd_lora_rank
|
||||
|
||||
# calculate scale_down and scale_up
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
|
||||
down_weight = down_weight * scale_down
|
||||
up_weight = up_weight * scale_up
|
||||
|
||||
# calculate dims if not provided
|
||||
num_splits = len(ait_keys)
|
||||
if dims is None:
|
||||
dims = [up_weight.shape[0] // num_splits] * num_splits
|
||||
else:
|
||||
assert sum(dims) == up_weight.shape[0]
|
||||
|
||||
# check upweight is sparse or not
|
||||
is_sparse = False
|
||||
if sd_lora_rank % num_splits == 0:
|
||||
ait_rank = sd_lora_rank // num_splits
|
||||
is_sparse = True
|
||||
i = 0
|
||||
for j in range(len(dims)):
|
||||
for k in range(len(dims)):
|
||||
if j == k:
|
||||
continue
|
||||
is_sparse = is_sparse and torch.all(
|
||||
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
|
||||
)
|
||||
i += dims[j]
|
||||
if is_sparse:
|
||||
logger.info(f"weight is sparse: {sds_key}")
|
||||
|
||||
# make ai-toolkit weight
|
||||
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
|
||||
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
|
||||
if not is_sparse:
|
||||
# down_weight is copied to each split
|
||||
ait_sd.update({k: down_weight for k in ait_down_keys})
|
||||
|
||||
# up_weight is split to each split
|
||||
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
|
||||
else:
|
||||
# down_weight is chunked to each split
|
||||
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
|
||||
|
||||
# up_weight is sparse: only non-zero values are copied to each split
|
||||
i = 0
|
||||
for j in range(len(dims)):
|
||||
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
|
||||
i += dims[j]
|
||||
|
||||
def _convert_sd_scripts_to_ai_toolkit(sds_sd):
|
||||
ait_sd = {}
|
||||
for i in range(19):
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_attn_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_out.0",
|
||||
)
|
||||
_convert_to_ai_toolkit_cat(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_attn_qkv",
|
||||
[
|
||||
f"transformer.transformer_blocks.{i}.attn.to_q",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_k",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_v",
|
||||
],
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_mlp_0",
|
||||
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_mlp_2",
|
||||
f"transformer.transformer_blocks.{i}.ff.net.2",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_img_mod_lin",
|
||||
f"transformer.transformer_blocks.{i}.norm1.linear",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_attn_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.to_add_out",
|
||||
)
|
||||
_convert_to_ai_toolkit_cat(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
|
||||
[
|
||||
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
|
||||
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
|
||||
],
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_mlp_0",
|
||||
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_mlp_2",
|
||||
f"transformer.transformer_blocks.{i}.ff_context.net.2",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_double_blocks_{i}_txt_mod_lin",
|
||||
f"transformer.transformer_blocks.{i}.norm1_context.linear",
|
||||
)
|
||||
|
||||
for i in range(38):
|
||||
_convert_to_ai_toolkit_cat(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_single_blocks_{i}_linear1",
|
||||
[
|
||||
f"transformer.single_transformer_blocks.{i}.attn.to_q",
|
||||
f"transformer.single_transformer_blocks.{i}.attn.to_k",
|
||||
f"transformer.single_transformer_blocks.{i}.attn.to_v",
|
||||
f"transformer.single_transformer_blocks.{i}.proj_mlp",
|
||||
],
|
||||
dims=[3072, 3072, 3072, 12288],
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_single_blocks_{i}_linear2",
|
||||
f"transformer.single_transformer_blocks.{i}.proj_out",
|
||||
)
|
||||
_convert_to_ai_toolkit(
|
||||
sds_sd,
|
||||
ait_sd,
|
||||
f"lora_unet_single_blocks_{i}_modulation_lin",
|
||||
f"transformer.single_transformer_blocks.{i}.norm.linear",
|
||||
)
|
||||
|
||||
if len(sds_sd) > 0:
|
||||
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
|
||||
|
||||
return ait_sd
|
||||
|
||||
return _convert_sd_scripts_to_ai_toolkit(state_dict)
|
||||
|
||||
|
||||
# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
|
||||
# Some utilities were reused from
|
||||
# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
|
||||
def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
new_state_dict = {}
|
||||
orig_keys = list(old_state_dict.keys())
|
||||
|
||||
def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
|
||||
down_weight = sds_sd.pop(sds_key)
|
||||
up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
|
||||
|
||||
# calculate dims if not provided
|
||||
num_splits = len(ait_keys)
|
||||
if dims is None:
|
||||
dims = [up_weight.shape[0] // num_splits] * num_splits
|
||||
else:
|
||||
assert sum(dims) == up_weight.shape[0]
|
||||
|
||||
# make ai-toolkit weight
|
||||
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
|
||||
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
|
||||
|
||||
# down_weight is copied to each split
|
||||
ait_sd.update({k: down_weight for k in ait_down_keys})
|
||||
|
||||
# up_weight is split to each split
|
||||
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
|
||||
|
||||
for old_key in orig_keys:
|
||||
# Handle double_blocks
|
||||
if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
|
||||
block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
|
||||
new_key = f"transformer.transformer_blocks.{block_num}"
|
||||
|
||||
if "processor.proj_lora1" in old_key:
|
||||
new_key += ".attn.to_out.0"
|
||||
elif "processor.proj_lora2" in old_key:
|
||||
new_key += ".attn.to_add_out"
|
||||
elif "processor.qkv_lora1" in old_key and "up" not in old_key:
|
||||
handle_qkv(
|
||||
old_state_dict,
|
||||
new_state_dict,
|
||||
old_key,
|
||||
[
|
||||
f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
|
||||
f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
|
||||
f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
|
||||
],
|
||||
)
|
||||
# continue
|
||||
elif "processor.qkv_lora2" in old_key and "up" not in old_key:
|
||||
handle_qkv(
|
||||
old_state_dict,
|
||||
new_state_dict,
|
||||
old_key,
|
||||
[
|
||||
f"transformer.transformer_blocks.{block_num}.attn.to_q",
|
||||
f"transformer.transformer_blocks.{block_num}.attn.to_k",
|
||||
f"transformer.transformer_blocks.{block_num}.attn.to_v",
|
||||
],
|
||||
)
|
||||
# continue
|
||||
|
||||
if "down" in old_key:
|
||||
new_key += ".lora_A.weight"
|
||||
elif "up" in old_key:
|
||||
new_key += ".lora_B.weight"
|
||||
|
||||
# Handle single_blocks
|
||||
elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"):
|
||||
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
|
||||
new_key = f"transformer.single_transformer_blocks.{block_num}"
|
||||
|
||||
if "proj_lora1" in old_key or "proj_lora2" in old_key:
|
||||
new_key += ".proj_out"
|
||||
elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
|
||||
new_key += ".norm.linear"
|
||||
|
||||
if "down" in old_key:
|
||||
new_key += ".lora_A.weight"
|
||||
elif "up" in old_key:
|
||||
new_key += ".lora_B.weight"
|
||||
|
||||
else:
|
||||
# Handle other potential key patterns here
|
||||
new_key = old_key
|
||||
|
||||
# Since we already handle qkv above.
|
||||
if "qkv" not in old_key:
|
||||
new_state_dict[new_key] = old_state_dict.pop(old_key)
|
||||
|
||||
if len(old_state_dict) > 0:
|
||||
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@@ -31,7 +31,12 @@ from ..utils import (
|
||||
scale_lora_layers,
|
||||
)
|
||||
from .lora_base import LoraBaseMixin
|
||||
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
|
||||
from .lora_conversion_utils import (
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_xlabs_flux_lora_to_diffusers,
|
||||
_maybe_map_sgm_blocks_to_diffusers,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -1583,6 +1588,20 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
|
||||
|
||||
is_kohya = any(".lora_down.weight" in k for k in state_dict)
|
||||
if is_kohya:
|
||||
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
|
||||
# Kohya already takes care of scaling the LoRA parameters with alpha.
|
||||
return (state_dict, None) if return_alphas else state_dict
|
||||
|
||||
is_xlabs = any("processor" in k for k in state_dict)
|
||||
if is_xlabs:
|
||||
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
|
||||
# xlabs doesn't use `alpha`.
|
||||
return (state_dict, None) if return_alphas else state_dict
|
||||
|
||||
# For state dicts like
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
Reference in New Issue
Block a user