mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[lora] factor out the overlaps in save_lora_weights(). (#12027)
* factor out the overlaps in save_lora_weights(). * remove comment. * remove comment. * up * fix-copies
This commit is contained in:
@@ -1064,6 +1064,41 @@ class LoraBaseMixin:
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
@classmethod
|
||||
def _save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
|
||||
lora_metadata: Dict[str, Optional[dict]],
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
"""
|
||||
Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
|
||||
pipeline types.
|
||||
"""
|
||||
state_dict = {}
|
||||
final_lora_adapter_metadata = {}
|
||||
|
||||
for prefix, layers in lora_layers.items():
|
||||
state_dict.update(cls.pack_weights(layers, prefix))
|
||||
|
||||
for prefix, metadata in lora_metadata.items():
|
||||
if metadata:
|
||||
final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
|
||||
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@@ -510,35 +510,28 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
text_encoder_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers):
|
||||
raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if unet_lora_layers:
|
||||
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
|
||||
lora_layers[cls.unet_name] = unet_lora_layers
|
||||
lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
|
||||
lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
|
||||
lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
|
||||
|
||||
if unet_lora_adapter_metadata:
|
||||
lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if text_encoder_lora_adapter_metadata:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
def fuse_lora(
|
||||
@@ -1004,44 +997,34 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
text_encoder_2_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
|
||||
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
||||
raise ValueError(
|
||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
|
||||
)
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if unet_lora_layers:
|
||||
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
|
||||
lora_layers[cls.unet_name] = unet_lora_layers
|
||||
lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
lora_layers["text_encoder"] = text_encoder_lora_layers
|
||||
lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
|
||||
|
||||
if text_encoder_2_lora_layers:
|
||||
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
|
||||
lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
|
||||
|
||||
if unet_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
|
||||
|
||||
if text_encoder_lora_adapter_metadata:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
|
||||
if not lora_layers:
|
||||
raise ValueError(
|
||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
|
||||
)
|
||||
|
||||
if text_encoder_2_lora_adapter_metadata:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
|
||||
)
|
||||
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
def fuse_lora(
|
||||
@@ -1467,46 +1450,34 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
text_encoder_2_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
|
||||
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
||||
raise ValueError(
|
||||
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
|
||||
)
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
lora_layers["text_encoder"] = text_encoder_lora_layers
|
||||
lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
|
||||
|
||||
if text_encoder_2_lora_layers:
|
||||
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
|
||||
lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
if not lora_layers:
|
||||
raise ValueError(
|
||||
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
|
||||
)
|
||||
|
||||
if text_encoder_lora_adapter_metadata:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
|
||||
)
|
||||
|
||||
if text_encoder_2_lora_adapter_metadata:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
|
||||
)
|
||||
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
@@ -1830,28 +1801,24 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
|
||||
@@ -2435,37 +2402,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
text_encoder_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
|
||||
if not (transformer_lora_layers or text_encoder_lora_layers):
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
|
||||
lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
|
||||
lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
|
||||
|
||||
if transformer_lora_adapter_metadata:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if text_encoder_lora_adapter_metadata:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
def fuse_lora(
|
||||
@@ -3254,28 +3212,24 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
def fuse_lora(
|
||||
@@ -3594,28 +3548,24 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
@@ -3938,28 +3888,24 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
@@ -4280,28 +4226,24 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
@@ -4624,28 +4566,24 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
@@ -4969,28 +4907,24 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
|
||||
@@ -5384,28 +5318,24 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
@@ -5802,28 +5732,24 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
@@ -6144,28 +6070,24 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
@@ -6488,28 +6410,24 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
|
||||
@@ -6835,28 +6753,24 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
|
||||
"""
|
||||
state_dict = {}
|
||||
lora_adapter_metadata = {}
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
if transformer_lora_adapter_metadata is not None:
|
||||
lora_adapter_metadata.update(
|
||||
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
|
||||
)
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
lora_adapter_metadata=lora_adapter_metadata,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
|
||||
Reference in New Issue
Block a user