From d1db4f853a8d5da0a4bc4112010bca8d900871ef Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Jun 2025 14:26:35 +0530 Subject: [PATCH] [LoRA] fix flux lora loader when return_metadata is true for non-diffusers (#11716) * fix flux lora loader when return_metadata is true for non-diffusers * remove annotation --- src/diffusers/loaders/lora_pipeline.py | 46 ++++++++++++++++++++----- src/diffusers/loaders/peft.py | 4 ++- src/diffusers/utils/state_dict_utils.py | 7 ++-- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 27053623ee..8fdd8a88ed 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -2031,18 +2031,36 @@ class FluxLoraLoaderMixin(LoraBaseMixin): if is_kohya: state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict) # Kohya already takes care of scaling the LoRA parameters with alpha. - return (state_dict, None) if return_alphas else state_dict + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) is_xlabs = any("processor" in k for k in state_dict) if is_xlabs: state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict) # xlabs doesn't use `alpha`. 
- return (state_dict, None) if return_alphas else state_dict + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) is_bfl_control = any("query_norm.scale" in k for k in state_dict) if is_bfl_control: state_dict = _convert_bfl_flux_control_lora_to_diffusers(state_dict) - return (state_dict, None) if return_alphas else state_dict + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=None, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) # For state dicts like # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA @@ -2061,12 +2079,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ) if return_alphas or return_lora_metadata: - outputs = [state_dict] - if return_alphas: - outputs.append(network_alphas) - if return_lora_metadata: - outputs.append(metadata) - return tuple(outputs) + return cls._prepare_outputs( + state_dict, + metadata=metadata, + alphas=network_alphas, + return_alphas=return_alphas, + return_metadata=return_lora_metadata, + ) else: return state_dict @@ -2785,6 +2804,15 @@ class FluxLoraLoaderMixin(LoraBaseMixin): raise ValueError("Either `base_module` or `base_weight_param_name` must be provided.") + @staticmethod + def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, return_metadata=False): + outputs = [state_dict] + if return_alphas: + outputs.append(alphas) + if return_metadata: + outputs.append(metadata) + return tuple(outputs) if (return_alphas or return_metadata) else state_dict + # The reason why we subclass from `StableDiffusionLoraLoaderMixin` here is because Amused initially # relied on `StableDiffusionLoraLoaderMixin` for its LoRA support. 
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index e7a458f28e..6bb6e36936 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -187,7 +187,9 @@ class PeftAdapterMixin: Note that hotswapping adapters of the text encoder is not yet supported. There are some further limitations to this technique, which are documented here: https://huggingface.co/docs/peft/main/en/package_reference/hotswap - metadata: TODO + metadata: + LoRA adapter metadata. When supplied, the metadata inferred through the state dict isn't used to + initialize `LoraConfig`. """ from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft.tuners.tuners_utils import BaseTunerLayer diff --git a/src/diffusers/utils/state_dict_utils.py b/src/diffusers/utils/state_dict_utils.py index 498f7e566c..8e6078488a 100644 --- a/src/diffusers/utils/state_dict_utils.py +++ b/src/diffusers/utils/state_dict_utils.py @@ -359,5 +359,8 @@ def _load_sft_state_dict_metadata(model_file: str): metadata = f.metadata() or {} metadata.pop("format", None) - raw = metadata.get(LORA_ADAPTER_METADATA_KEY) - return json.loads(raw) if raw else None + if metadata: + raw = metadata.get(LORA_ADAPTER_METADATA_KEY) + return json.loads(raw) if raw else None + else: + return None