Mirror of https://github.com/huggingface/diffusers.git

[lora] factor out the overlaps in save_lora_weights(). (#12027)

* factor out the overlaps in save_lora_weights().

* remove comment.

* remove comment.

* up

* fix-copies
Commit: d8310a8fca (parent 78031c2938)
Author: Sayak Paul, committed by GitHub
Date: 2025-09-22 15:14:39 +05:30
2 changed files with 205 additions and 256 deletions
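The refactor is internal: the public `save_lora_weights()` classmethods keep their signatures, and only the shared packing and serialization logic moves into the new `_save_lora_weights()` helper shown in the first hunk below. As a rough sketch of the unchanged caller-facing behaviour (the tensor key, shape, and output directory are placeholders, not taken from this commit):

import torch
from diffusers import StableDiffusionPipeline  # inherits StableDiffusionLoraLoaderMixin

# save_lora_weights() is a classmethod, so it can be exercised without
# instantiating a pipeline; a {name: tensor} dict stands in for trained layers.
dummy_unet_lora = {"mid_block.attentions.0.to_q.lora_A.weight": torch.zeros(4, 1280)}

StableDiffusionPipeline.save_lora_weights(
    save_directory="./example-lora",  # placeholder path
    unet_lora_layers=dummy_unet_lora,
    safe_serialization=True,  # default weight name is pytorch_lora_weights.safetensors
)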


@@ -1064,6 +1064,41 @@ class LoraBaseMixin:
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")

+    @classmethod
+    def _save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
+        lora_metadata: Dict[str, Optional[dict]],
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        """
+        Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
+        pipeline types.
+        """
+        state_dict = {}
+        final_lora_adapter_metadata = {}
+
+        for prefix, layers in lora_layers.items():
+            state_dict.update(cls.pack_weights(layers, prefix))
+
+        for prefix, metadata in lora_metadata.items():
+            if metadata:
+                final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
+
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+            lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
+        )
+
     @classmethod
     def _optionally_disable_offloading(cls, _pipeline):
         return _func_optionally_disable_offloading(_pipeline=_pipeline)
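For readers unfamiliar with the packing step the helper centralizes: `pack_weights()` namespaces each component's LoRA tensors under its prefix so one flat state dict can hold several components. A minimal self-contained sketch of that idea, assuming the prefix is simply dot-joined onto each key (the key names and shapes below are illustrative, not from this commit):

import torch


def pack_weights(layers: dict, prefix: str) -> dict:
    # Toy stand-in for LoraBaseMixin.pack_weights: prepend the component
    # prefix so keys from different components cannot collide.
    return {f"{prefix}.{name}": tensor for name, tensor in layers.items()}


unet_layers = {"down_blocks.0.attentions.0.to_q.lora_A.weight": torch.zeros(4, 320)}
text_encoder_layers = {"text_model.encoder.layers.0.q_proj.lora_A.weight": torch.zeros(4, 768)}

state_dict = {}
for prefix, layers in {"unet": unet_layers, "text_encoder": text_encoder_layers}.items():
    state_dict.update(pack_weights(layers, prefix))

print(sorted(state_dict))
# ['text_encoder.text_model.encoder.layers.0.q_proj.lora_A.weight',
#  'unet.down_blocks.0.attentions.0.to_q.lora_A.weight']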


@@ -510,35 +510,28 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             text_encoder_lora_adapter_metadata:
                 LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not (unet_lora_layers or text_encoder_lora_layers):
-            raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if unet_lora_layers:
-            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+            lora_layers[cls.unet_name] = unet_lora_layers
+            lora_metadata[cls.unet_name] = unet_lora_adapter_metadata

         if text_encoder_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+            lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
+            lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata

-        if unet_lora_adapter_metadata:
-            lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
-
-        if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
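The `lora_metadata` dict built above maps each component prefix to its (possibly None) adapter metadata; inside `_save_lora_weights()` those entries are flattened with `_pack_dict_with_prefix()`. A small illustrative sketch, assuming that helper simply dot-prefixes every key (the rank/alpha values are made up, not from this commit):

def pack_dict_with_prefix(metadata: dict, prefix: str) -> dict:
    # Toy stand-in for diffusers' _pack_dict_with_prefix.
    return {f"{prefix}.{key}": value for key, value in metadata.items()}


lora_metadata = {"unet": {"r": 8, "lora_alpha": 16}, "text_encoder": None}

final_lora_adapter_metadata = {}
for prefix, metadata in lora_metadata.items():
    if metadata:  # components without metadata (None) are skipped
        final_lora_adapter_metadata.update(pack_dict_with_prefix(metadata, prefix))

print(final_lora_adapter_metadata)
# {'unet.r': 8, 'unet.lora_alpha': 16}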
@@ -1004,44 +997,34 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             text_encoder_2_lora_adapter_metadata:
                 LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
-            raise ValueError(
-                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
-            )
+        lora_layers = {}
+        lora_metadata = {}

         if unet_lora_layers:
-            state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+            lora_layers[cls.unet_name] = unet_lora_layers
+            lora_metadata[cls.unet_name] = unet_lora_adapter_metadata

         if text_encoder_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+            lora_layers["text_encoder"] = text_encoder_lora_layers
+            lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata

         if text_encoder_2_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+            lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+            lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata

-        if unet_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
-
-        if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
-            )
-
-        if text_encoder_2_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
-            )
+        if not lora_layers:
+            raise ValueError(
+                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
+            )

-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
@@ -1467,46 +1450,34 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             text_encoder_2_lora_adapter_metadata:
                 LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
-            raise ValueError(
-                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
-            )
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

         if text_encoder_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+            lora_layers["text_encoder"] = text_encoder_lora_layers
+            lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata

         if text_encoder_2_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+            lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+            lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
-
-        if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
-            )
-
-        if text_encoder_2_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
-            )
+        if not lora_layers:
+            raise ValueError(
+                "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
+            )

-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
@@ -1830,28 +1801,24 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -2435,37 +2402,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
             text_encoder_lora_adapter_metadata:
                 LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not (transformer_lora_layers or text_encoder_lora_layers):
-            raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

         if text_encoder_lora_layers:
-            state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+            lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
+            lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata

-        if transformer_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
-
-        if text_encoder_lora_adapter_metadata:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
@@ -3254,28 +3212,24 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
@@ -3594,28 +3548,24 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -3938,28 +3888,24 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4280,28 +4226,24 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4624,28 +4566,24 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4969,28 +4907,24 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -5384,28 +5318,24 @@ class WanLoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -5802,28 +5732,24 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -6144,28 +6070,24 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -6488,28 +6410,24 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -6835,28 +6753,24 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
             transformer_lora_adapter_metadata:
                 LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
-        state_dict = {}
-        lora_adapter_metadata = {}
-
-        if not transformer_lora_layers:
-            raise ValueError("You must pass `transformer_lora_layers`.")
+        lora_layers = {}
+        lora_metadata = {}

         if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+            lora_layers[cls.transformer_name] = transformer_lora_layers
+            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

-        if transformer_lora_adapter_metadata is not None:
-            lora_adapter_metadata.update(
-                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
-            )
+        if not lora_layers:
+            raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")

-        # Save the model
-        cls.write_lora_layers(
-            state_dict=state_dict,
+        cls._save_lora_weights(
             save_directory=save_directory,
+            lora_layers=lora_layers,
+            lora_metadata=lora_metadata,
             is_main_process=is_main_process,
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
-            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora