Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[LoRA serialization] fix: duplicate unet prefix problem. (#5991)
* fix: duplicate unet prefix problem.
* Update src/diffusers/loaders/lora.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
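For context: the duplicated prefix came from prefixing twice. `unet_lora_state_dict` returned keys that already started with `unet.`, and `save_lora_weights` prepended `unet.` again when serializing. A minimal sketch of the pre-fix behavior (the module name is made up for illustration):

    # Keys from unet_lora_state_dict already carried the "unet." prefix ...
    layer_sd = {"unet.down_blocks.0.attentions.0.to_q.lora.down": 0}
    # ... and save_lora_weights prefixed them a second time.
    serialized = {f"unet.{k}": v for k, v in layer_sd.items()}
    print(list(serialized))  # ['unet.unet.down_blocks.0.attentions.0.to_q.lora.down']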
@@ -391,6 +391,10 @@ class LoraLoaderMixin:
         # their prefixes.
         keys = list(state_dict.keys())
 
+        if all(key.startswith("unet.unet") for key in keys):
+            deprecation_message = "Keys starting with 'unet.unet' are deprecated."
+            deprecate("unet.unet keys", "0.27", deprecation_message)
+
         if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
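Checkpoints that were already serialized with the doubled prefix will now trigger the `deprecate` call above, with removal targeted for 0.27. As a hedged aside (not part of this diff), such a state dict could be normalized before loading along these lines:

    # Hypothetical cleanup for an already double-prefixed LoRA state dict.
    state_dict = {"unet.unet.down_blocks.0.attentions.0.to_q.lora.down": 0}
    if all(key.startswith("unet.unet") for key in state_dict):
        state_dict = {k.replace("unet.unet.", "unet.", 1): v for k, v in state_dict.items()}
    print(list(state_dict))  # ['unet.down_blocks.0.attentions.0.to_q.lora.down']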
@@ -407,8 +411,9 @@ class LoraLoaderMixin:
         else:
             # Otherwise, we're dealing with the old format. This means the `state_dict` should only
             # contain the module names of the `unet` as its keys WITHOUT any prefix.
-            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            logger.warn(warn_message)
+            if not USE_PEFT_BACKEND:
+                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
+                logger.warn(warn_message)
 
         if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
             from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
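The old-format warning is now emitted only on the non-PEFT code path. The conversion it recommends for legacy, prefix-less LoRA dicts is the one spelled out in the message itself:

    # The recipe quoted in warn_message; the module name is illustrative.
    old_state_dict = {"down_blocks.0.attentions.0.to_q.lora.down": 0}
    new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}
    print(list(new_state_dict))  # ['unet.down_blocks.0.attentions.0.to_q.lora.down']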
@@ -800,29 +805,21 @@ class LoraLoaderMixin:
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         # Create a flat dictionary.
         state_dict = {}
 
-        # Populate the dictionary.
-        if unet_lora_layers is not None:
-            weights = (
-                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
-            )
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
 
-            unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
-            state_dict.update(unet_lora_state_dict)
+        if not (unet_lora_layers or text_encoder_lora_layers):
+            raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
 
-        if text_encoder_lora_layers is not None:
-            weights = (
-                text_encoder_lora_layers.state_dict()
-                if isinstance(text_encoder_lora_layers, torch.nn.Module)
-                else text_encoder_lora_layers
-            )
+        if unet_lora_layers:
+            state_dict.update(pack_weights(unet_lora_layers, "unet"))
 
-            text_encoder_lora_state_dict = {
-                f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
-            }
-            state_dict.update(text_encoder_lora_state_dict)
+        if text_encoder_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
 
         # Save the model
         cls.write_lora_layers(
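With the shared `pack_weights` helper, each component prefix is added exactly once and the per-component branches shrink to one line each. A small standalone sketch of what the helper produces (the tensors and module names are placeholders):

    import torch

    def pack_weights(layers, prefix):
        layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
        return layers_state_dict

    state_dict = {}
    state_dict.update(pack_weights({"to_q.lora.down.weight": torch.zeros(4, 4)}, "unet"))
    state_dict.update(pack_weights({"q_proj.lora.up.weight": torch.zeros(4, 4)}, "text_encoder"))
    print(sorted(state_dict))  # ['text_encoder.q_proj.lora.up.weight', 'unet.to_q.lora.down.weight']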
@@ -67,7 +67,7 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
                 current_lora_layer_sd = lora_layer.state_dict()
                 for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items():
                     # The matrix name can either be "down" or "up".
-                    lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param
+                    lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param
 
     return lora_state_dict
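After this last hunk, `unet_lora_state_dict` emits un-prefixed keys, so the single `unet.` prefix is added only once, by `pack_weights`, at save time (module name again hypothetical):

    layer_sd = {"down_blocks.0.attentions.0.to_q.lora.down": 0}  # unet_lora_state_dict output after the fix
    serialized = {f"unet.{k}": v for k, v in layer_sd.items()}   # what pack_weights(layer_sd, "unet") yields
    print(list(serialized))  # ['unet.down_blocks.0.attentions.0.to_q.lora.down']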