From d8310a8fca812ff4afdcc1fd09c135dcccaab670 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 22 Sep 2025 15:14:39 +0530 Subject: [PATCH] [lora] factor out the overlaps in `save_lora_weights()`. (#12027) * factor out the overlaps in save_lora_weights(). * remove comment. * remove comment. * up * fix-copies --- src/diffusers/loaders/lora_base.py | 35 ++ src/diffusers/loaders/lora_pipeline.py | 426 ++++++++++--------------- 2 files changed, 205 insertions(+), 256 deletions(-) diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index d18c82df4f..0ee32f820b 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -1064,6 +1064,41 @@ class LoraBaseMixin: save_function(state_dict, save_path) logger.info(f"Model weights saved in {save_path}") + @classmethod + def _save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]], + lora_metadata: Dict[str, Optional[dict]], + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + """ + Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all + pipeline types. + """ + state_dict = {} + final_lora_adapter_metadata = {} + + for prefix, layers in lora_layers.items(): + state_dict.update(cls.pack_weights(layers, prefix)) + + for prefix, metadata in lora_metadata.items(): + if metadata: + final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix)) + + cls.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None, + ) + @classmethod def _optionally_disable_offloading(cls, _pipeline): return _func_optionally_disable_offloading(_pipeline=_pipeline) diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 7e89066f1f..8060b519f1 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -510,35 +510,28 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin): text_encoder_lora_adapter_metadata: LoRA adapter metadata associated with the text encoder to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} - - if not (unet_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.") + lora_layers = {} + lora_metadata = {} if unet_lora_layers: - state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) + lora_layers[cls.unet_name] = unet_lora_layers + lora_metadata[cls.unet_name] = unet_lora_adapter_metadata if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + lora_layers[cls.text_encoder_name] = text_encoder_lora_layers + lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata - if unet_lora_adapter_metadata: - lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.") - if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -1004,44 +997,34 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): text_encoder_2_lora_adapter_metadata: LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} - - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." - ) + lora_layers = {} + lora_metadata = {} if unet_lora_layers: - state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name)) + lora_layers[cls.unet_name] = unet_lora_layers + lora_metadata[cls.unet_name] = unet_lora_adapter_metadata if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) + lora_layers["text_encoder"] = text_encoder_lora_layers + lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata if text_encoder_2_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + lora_layers["text_encoder_2"] = text_encoder_2_lora_layers + lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata - if unet_lora_adapter_metadata is not None: - lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name)) - - if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) + if not lora_layers: + raise ValueError( + "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`." ) - if text_encoder_2_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") - ) - - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -1467,46 +1450,34 @@ class SD3LoraLoaderMixin(LoraBaseMixin): text_encoder_2_lora_adapter_metadata: LoRA adapter metadata associated with the second text encoder to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} - - if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`." - ) + lora_layers = {} + lora_metadata = {} if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) + lora_layers["text_encoder"] = text_encoder_lora_layers + lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata if text_encoder_2_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + lora_layers["text_encoder_2"] = text_encoder_2_lora_layers + lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) + if not lora_layers: + raise ValueError( + "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`." ) - if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) - ) - - if text_encoder_2_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2") - ) - - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer @@ -1830,28 +1801,24 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -2435,37 +2402,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin): text_encoder_lora_adapter_metadata: LoRA adapter metadata associated with the text encoder to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} - - if not (transformer_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") + lora_layers = {} + lora_metadata = {} if transformer_lora_layers: - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata if text_encoder_lora_layers: - state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) + lora_layers[cls.text_encoder_name] = text_encoder_lora_layers + lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata - if transformer_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if text_encoder_lora_adapter_metadata: - lora_adapter_metadata.update( - _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -3254,28 +3212,24 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) def fuse_lora( @@ -3594,28 +3548,24 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -3938,28 +3888,24 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4280,28 +4226,24 @@ class SanaLoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4624,28 +4566,24 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -4969,28 +4907,24 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -5384,28 +5318,24 @@ class WanLoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -5802,28 +5732,24 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -6144,28 +6070,24 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora @@ -6488,28 +6410,24 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora @@ -6835,28 +6753,24 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin): transformer_lora_adapter_metadata: LoRA adapter metadata associated with the transformer to be serialized with the state dict. """ - state_dict = {} - lora_adapter_metadata = {} + lora_layers = {} + lora_metadata = {} - if not transformer_lora_layers: - raise ValueError("You must pass `transformer_lora_layers`.") + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata - state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.") - if transformer_lora_adapter_metadata is not None: - lora_adapter_metadata.update( - _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) - ) - - # Save the model - cls.write_lora_layers( - state_dict=state_dict, + cls._save_lora_weights( save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, is_main_process=is_main_process, weight_name=weight_name, save_function=save_function, safe_serialization=safe_serialization, - lora_adapter_metadata=lora_adapter_metadata, ) # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora