Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[LoRA] feat: support non-diffusers lumina2 LoRAs. (#10909)
* feat: support non-diffusers lumina2 LoRAs.
* revert ipynb changes (but I don't know why this is required ☹️)
* empty

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
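For context, a minimal usage sketch of what this change enables. The pipeline class, checkpoint id, and LoRA filename below are illustrative assumptions, not taken from this PR; the point is that `load_lora_weights` on the Lumina2 pipeline now accepts checkpoints whose keys use the original `diffusion_model.` layout.

```python
import torch
from diffusers import Lumina2Pipeline  # assumed pipeline class for Lumina Image 2.0

# Hypothetical checkpoint id and LoRA path, for illustration only.
pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)

# After this PR, keys starting with "diffusion_model." are detected and
# converted to the diffusers layout automatically during loading.
pipe.load_lora_weights("path/to/non_diffusers_lumina2_lora.safetensors")
```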
src/diffusers/loaders/lora_conversion_utils.py:

```diff
@@ -1276,3 +1276,74 @@ def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
 
     return converted_state_dict
+
+
+def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
+    # Remove "diffusion_model." prefix from keys.
+    state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
+    converted_state_dict = {}
+
+    def get_num_layers(keys, pattern):
+        layers = set()
+        for key in keys:
+            match = re.search(pattern, key)
+            if match:
+                layers.add(int(match.group(1)))
+        return len(layers)
+
+    def process_block(prefix, index, convert_norm):
+        # Process attention qkv: pop lora_A and lora_B weights.
+        lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
+        lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
+        for attn_key in ["to_q", "to_k", "to_v"]:
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
+        for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
+            converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
+
+        # Process attention out weights.
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_A.weight"
+        )
+        converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
+            f"{prefix}.{index}.attention.out.lora_B.weight"
+        )
+
+        # Process feed-forward weights for layers 1, 2, and 3.
+        for layer in range(1, 4):
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
+            )
+
+        if convert_norm:
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
+            )
+            converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
+                f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
+            )
+
+    noise_refiner_pattern = r"noise_refiner\.(\d+)\."
+    num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
+    for i in range(num_noise_refiner_layers):
+        process_block("noise_refiner", i, convert_norm=True)
+
+    context_refiner_pattern = r"context_refiner\.(\d+)\."
+    num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
+    for i in range(num_context_refiner_layers):
+        process_block("context_refiner", i, convert_norm=False)
+
+    core_transformer_pattern = r"layers\.(\d+)\."
+    num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
+    for i in range(num_core_transformer_layers):
+        process_block("layers", i, convert_norm=True)
+
+    if len(state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
```
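A toy sketch of the fused-qkv handling in `process_block` (tensor shapes are illustrative; the `[2304, 768, 768]` split sizes come from the converter above). The original checkpoint stores one LoRA pair for the fused qkv projection, so `lora_A` is reused verbatim for `to_q`, `to_k`, and `to_v`, and only `lora_B` is split row-wise:

```python
import torch

rank = 4  # illustrative LoRA rank
lora_down = torch.randn(rank, 2304)            # lora_A: shared input projection
lora_up = torch.randn(2304 + 768 + 768, rank)  # lora_B: fused q/k/v output rows

# Split lora_B into per-projection blocks; lora_A stays shared.
to_q_up, to_k_up, to_v_up = torch.split(lora_up, [2304, 768, 768], dim=0)
assert to_q_up.shape == (2304, rank)
assert to_k_up.shape == (768, rank) and to_v_up.shape == (768, rank)
```

The split is exact: the fused delta `lora_up @ lora_down` stacks the q, k, and v output rows, so splitting `lora_up` along dim 0 while sharing `lora_down` reproduces the original computation block by block.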
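`get_num_layers` counts distinct captured indices rather than matching keys, since every layer contributes many keys. A small sketch with hand-written keys (illustrative only):

```python
import re

keys = [
    "noise_refiner.0.attention.qkv.lora_A.weight",
    "noise_refiner.0.attention.qkv.lora_B.weight",
    "noise_refiner.1.attention.out.lora_A.weight",
]
layers = set()
for key in keys:
    match = re.search(r"noise_refiner\.(\d+)\.", key)
    if match:
        layers.add(int(match.group(1)))
print(len(layers))  # 2 layers, even though 3 keys matched
```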
src/diffusers/loaders/lora_pipeline.py:

```diff
@@ -41,6 +41,7 @@ from .lora_conversion_utils import (
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
+    _convert_non_diffusers_lumina2_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
 )
```
```diff
@@ -3815,7 +3816,6 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
 
     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
```
```diff
@@ -3909,6 +3909,11 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        # conversion.
+        non_diffusers = any(k.startswith("diffusion_model.") for k in state_dict)
+        if non_diffusers:
+            state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
+
         return state_dict
 
     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
```
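To make the new dispatch concrete, a small sketch with hand-written keys (illustrative, not from a real checkpoint). Only checkpoints whose keys carry the `diffusion_model.` prefix are routed through the converter; already-converted diffusers checkpoints pass through untouched:

```python
non_diffusers_keys = ["diffusion_model.layers.0.attention.qkv.lora_A.weight"]
diffusers_keys = ["transformer.layers.0.attn.to_q.lora_A.weight"]

print(any(k.startswith("diffusion_model.") for k in non_diffusers_keys))  # True
print(any(k.startswith("diffusion_model.") for k in diffusers_keys))      # False
```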