
[LoRA] fix flux lora loader when return_metadata is true for non-diffusers (#11716)

* fix flux lora loader when return_metadata is true for non-diffusers

* remove annotation
Author: Sayak Paul
Date: 2025-06-16 14:26:35 +05:30
Committed by: GitHub
parent 8adc6003ba
commit d1db4f853a
3 changed files with 45 additions and 12 deletions
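
Before this change, the early returns for non-diffusers checkpoints (Kohya, XLabs, BFL control) were `return (state_dict, None) if return_alphas else state_dict`, so a caller that also asked for metadata never received it. A rough usage sketch of the fixed call path (the repo id is hypothetical, and the keyword arguments assume the signature implied by the diff below):

from diffusers import FluxPipeline

# Hypothetical non-diffusers (Kohya-format) Flux LoRA checkpoint.
out = FluxPipeline.lora_state_dict(
    "some-user/some-kohya-flux-lora",
    return_alphas=True,
    return_lora_metadata=True,
)
# With the new `_prepare_outputs` helper the flags are honored for every format:
# out == (state_dict, network_alphas, metadata); for Kohya files the alphas slot
# is None because alpha is already folded into the weights.
state_dict, network_alphas, metadata = out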


@@ -2031,18 +2031,36 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        if is_kohya:
            state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
            # Kohya already takes care of scaling the LoRA parameters with alpha.
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )

        is_xlabs = any("processor" in k for k in state_dict)
        if is_xlabs:
            state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
            # xlabs doesn't use `alpha`.
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )

        is_bfl_control = any("query_norm.scale" in k for k in state_dict)
        if is_bfl_control:
            state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict)
-            return (state_dict, None) if return_alphas else state_dict
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )

        # For state dicts like
        # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
@@ -2061,12 +2079,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                    )

        if return_alphas or return_lora_metadata:
-            outputs = [state_dict]
-            if return_alphas:
-                outputs.append(network_alphas)
-            if return_lora_metadata:
-                outputs.append(metadata)
-            return tuple(outputs)
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=network_alphas,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
        else:
            return state_dict
@@ -2785,6 +2804,15 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.")
@staticmethod
def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
outputs = [state_dict]
if return_alphas:
outputs.append(alphas)
if return_metadata:
outputs.append(metadata)
return tuple(outputs) if (return_alphas or return_metadata) else state_dict
# The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially
# relied on `StableDiffusionLoraLoaderMixin` for its LoRA support.
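
To make the return shaping concrete, here is the new helper reproduced standalone (outside the class) together with the tuples it yields:

def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False):
    outputs = [state_dict]
    if return_alphas:
        outputs.append(alphas)
    if return_metadata:
        outputs.append(metadata)
    return tuple(outputs) if (return_alphas or return_metadata) else state_dict

sd, meta = {"w": 1}, {"r": 16}
print(_prepare_outputs(sd, meta))                                            # {'w': 1}
print(_prepare_outputs(sd, meta, return_alphas=True))                        # ({'w': 1}, None)
print(_prepare_outputs(sd, meta, return_alphas=True, return_metadata=True))  # ({'w': 1}, None, {'r': 16})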


@@ -187,7 +187,9 @@ class PeftAdapterMixin:
                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                limitations to this technique, which are documented here:
                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
-            metadata: TODO
+            metadata:
+                LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to
+                initialize `LoraConfig`.
        """
        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
        from peft.tuners.tuners_utils import BaseTunerLayer
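
The new docstring text above describes supplied metadata taking precedence over a config inferred from the state dict. A minimal sketch of that precedence, assuming the stored metadata maps directly to `LoraConfig` keyword arguments (the helper name below is hypothetical, not the library's code path):

from peft import LoraConfig

def build_lora_config(metadata=None, inferred_kwargs=None):
    # If adapter metadata was shipped with the checkpoint, trust it outright.
    if metadata is not None:
        return LoraConfig(**metadata)
    # Otherwise fall back to whatever was inferred from the state dict keys.
    return LoraConfig(**(inferred_kwargs or {}))

config = build_lora_config(metadata={"r": 16, "lora_alpha": 16, "target_modules": ["to_q", "to_k", "to_v"]})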


@@ -359,5 +359,8 @@ def _load_sft_state_dict_metadata(model_file: str):
            metadata = f.metadata() or {}

    metadata.pop("format", None)
-    raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
-    return json.loads(raw) if raw else None
+    if metadata:
+        raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+        return json.loads(raw) if raw else None
+    else:
+        return None
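
For reference, the helper after this change amounts to the following self-contained sketch (the value assigned to `LORA_ADAPTER_METADATA_KEY` here is an assumption for illustration; the real constant is defined in diffusers):

import json

import safetensors

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"  # assumed value, for illustration only

def load_adapter_metadata(model_file: str):
    # Read the header-level metadata stored alongside the tensors.
    with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
        metadata = f.metadata() or {}
    metadata.pop("format", None)  # drop the generic safetensors "format" entry
    if metadata:
        raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
        return json.loads(raw) if raw else None
    else:
        return None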