Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Flux Kontext] Support Fal Kontext LoRA (#11823)
* initial commit
* initial commit
* initial commit
* fix import
* fix prefix
* remove print
* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
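For orientation, a rough usage sketch (not part of the commit): with this conversion path in place, a Fal-trained Kontext LoRA loads through the pipeline's regular `load_lora_weights` entry point. The LoRA repository id and weight file name below are placeholders, and the pipeline class and base checkpoint are assumed to be the standard Flux Kontext ones.

import torch
from diffusers import FluxKontextPipeline

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Placeholder repo/file: any LoRA exported by Fal's Kontext trainer whose keys carry
# the "base_model.model." prefix should be detected and converted automatically.
pipe.load_lora_weights(
    "your-username/your-fal-kontext-lora", weight_name="pytorch_lora_weights.safetensors"
)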
@@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    return converted_state_dict


def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
    converted_state_dict = {}
    original_state_dict_keys = list(original_state_dict.keys())
    num_layers = 19
    num_single_layers = 38
    inner_dim = 3072
    mlp_ratio = 4.0

    # double transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        original_block_prefix = "base_model.model."

        for lora_key in ["lora_A", "lora_B"]:
            # norms
            converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
                f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
            )
            if f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
                    f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
                )

converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
|
||||
)
|
||||
|
||||
# Q, K, V
|
||||
if lora_key == "lora_A":
|
||||
sample_lora_weight = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
|
||||
|
||||
context_lora_weight = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
|
||||
[context_lora_weight]
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
|
||||
[context_lora_weight]
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
|
||||
[context_lora_weight]
|
||||
)
|
||||
else:
|
||||
sample_q, sample_k, sample_v = torch.chunk(
|
||||
original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
|
||||
),
|
||||
3,
|
||||
dim=0,
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
|
||||
|
||||
context_q, context_k, context_v = torch.chunk(
|
||||
original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
|
||||
),
|
||||
3,
|
||||
dim=0,
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
|
||||
|
||||
if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
|
||||
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
||||
original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
|
||||
3,
|
||||
dim=0,
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
|
||||
|
||||
if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
|
||||
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
||||
original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
|
||||
3,
|
||||
dim=0,
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
|
||||
|
||||
            # ff img_mlp
            converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
                f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
            )
            if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
                    f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
                f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
            )
            if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
                    f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
                f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
            )
            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
                )

            converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
                f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
            )
            if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
                    f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
                )

            # output projections.
            converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
                f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
            )
            if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
                    f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
                )
            converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
                f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
            )
            if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
                    f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
                )

    # single transformer blocks
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."

        for lora_key in ["lora_A", "lora_B"]:
            # norm.linear <- single_blocks.0.modulation.lin
            converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
                f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
            )
            if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
                    f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
                )

            # Q, K, V, mlp
            mlp_hidden_dim = int(inner_dim * mlp_ratio)
            split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

            if lora_key == "lora_A":
                lora_weight = original_state_dict.pop(
                    f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
                )
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])

if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
|
||||
lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
            else:
                q, k, v, mlp = torch.split(
                    original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
                    split_size,
                    dim=0,
                )
                converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
                converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
                converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
                converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])

                if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
                    q_bias, k_bias, v_bias, mlp_bias = torch.split(
                        original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
                        split_size,
                        dim=0,
                    )
                    converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
                    converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
                    converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
                    converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])

            # output projections.
            converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
                f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
            )
            if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
                converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
                    f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
                )

    for lora_key in ["lora_A", "lora_B"]:
        converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
            f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
        )
        if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
            converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
                f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
            )

    if len(original_state_dict) > 0:
        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")

    for key in list(converted_state_dict.keys()):
        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

    return converted_state_dict


def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
    converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
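As an aside, a minimal round-trip sketch of what the converter above does (illustrative only, not part of this PR or its test suite; it assumes a diffusers install that already ships `_convert_fal_kontext_lora_to_diffusers`): build a dummy Fal-style Kontext LoRA state dict with rank-4 matrices for every module the function expects, convert it, and check that the result uses diffusers' `transformer.`-prefixed naming.

# Illustrative sanity check with dummy tensors; shapes follow Flux's layout (inner_dim=3072, mlp_ratio=4.0).
import torch

from diffusers.loaders.lora_conversion_utils import _convert_fal_kontext_lora_to_diffusers

rank, inner_dim, mlp_hidden_dim = 4, 3072, int(3072 * 4.0)
prefix = "base_model.model."
state_dict = {}


def add(name, out_features, in_features):
    # lora_A maps in_features -> rank, lora_B maps rank -> out_features
    state_dict[f"{prefix}{name}.lora_A.weight"] = torch.zeros(rank, in_features)
    state_dict[f"{prefix}{name}.lora_B.weight"] = torch.zeros(out_features, rank)


for i in range(19):  # double-stream blocks
    add(f"double_blocks.{i}.img_mod.lin", 6 * inner_dim, inner_dim)
    add(f"double_blocks.{i}.txt_mod.lin", 6 * inner_dim, inner_dim)
    add(f"double_blocks.{i}.img_attn.qkv", 3 * inner_dim, inner_dim)
    add(f"double_blocks.{i}.txt_attn.qkv", 3 * inner_dim, inner_dim)
    add(f"double_blocks.{i}.img_attn.proj", inner_dim, inner_dim)
    add(f"double_blocks.{i}.txt_attn.proj", inner_dim, inner_dim)
    add(f"double_blocks.{i}.img_mlp.0", mlp_hidden_dim, inner_dim)
    add(f"double_blocks.{i}.img_mlp.2", inner_dim, mlp_hidden_dim)
    add(f"double_blocks.{i}.txt_mlp.0", mlp_hidden_dim, inner_dim)
    add(f"double_blocks.{i}.txt_mlp.2", inner_dim, mlp_hidden_dim)

for i in range(38):  # single-stream blocks: fused linear1 = [q, k, v, mlp]
    add(f"single_blocks.{i}.modulation.lin", 3 * inner_dim, inner_dim)
    add(f"single_blocks.{i}.linear1", 3 * inner_dim + mlp_hidden_dim, inner_dim)
    add(f"single_blocks.{i}.linear2", inner_dim, inner_dim + mlp_hidden_dim)

add("final_layer.linear", 64, inner_dim)

converted = _convert_fal_kontext_lora_to_diffusers(state_dict)
assert all(k.startswith("transformer.") for k in converted)
print(len(converted), "keys, e.g.", sorted(converted)[0])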
@@ -41,6 +41,7 @@ from .lora_base import ( # noqa
)
from .lora_conversion_utils import (
    _convert_bfl_flux_control_lora_to_diffusers,
    _convert_fal_kontext_lora_to_diffusers,
    _convert_hunyuan_video_lora_to_diffusers,
    _convert_kohya_flux_lora_to_diffusers,
    _convert_musubi_wan_lora_to_diffusers,
@@ -2062,6 +2063,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                return_metadata=return_lora_metadata,
            )

        is_fal_kontext = any("base_model" in k for k in state_dict)
        if is_fal_kontext:
            state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
            return cls._prepare_outputs(
                state_dict,
                metadata=metadata,
                alphas=None,
                return_alphas=return_alphas,
                return_metadata=return_lora_metadata,
            )

        # For state dicts like
        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
        keys = list(state_dict.keys())
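The `is_fal_kontext` check above keys off the PEFT-style `base_model.model.` wrapper prefix that Fal's trainer leaves in its checkpoints, which Diffusers- and Kohya-style Flux LoRAs do not use. A tiny illustration with made-up key names:

# Illustrative key names: Fal-exported keys keep the PEFT "base_model.model." prefix,
# so a simple substring test is enough to route them to the new converter.
fal_keys = [
    "base_model.model.double_blocks.0.img_attn.qkv.lora_A.weight",
    "base_model.model.single_blocks.0.linear1.lora_B.weight",
]
diffusers_keys = [
    "transformer.transformer_blocks.0.attn.to_q.lora_A.weight",
]

assert any("base_model" in k for k in fal_keys)
assert not any("base_model" in k for k in diffusers_keys)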