Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[LoRA] parse metadata from LoRA and save metadata (#11324)
* feat: parse metadata from lora state dicts.
* tests
* fix tests
* key renaming
* fix
* smol update
* smol updates
* load metadata.
* automatically save metadata in save_lora_adapter.
* propagate changes.
* changes
* add test to models too.
* tighter tests.
* updates
* fixes
* rename tests.
* sorted.
* Update src/diffusers/loaders/lora_base.py
  Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* review suggestions.
* removeprefix.
* propagate changes.
* fix-copies
* sd
* docs.
* fixes
* get review ready.
* one more test to catch error.
* change to a different approach.
* fix-copies.
* todo
* sd3
* update
* revert changes in get_peft_kwargs.
* update
* fixes
* fixes
* simplify _load_sft_state_dict_metadata
* update
* style fix
* update
* update
* update
* empty commit
* _pack_dict_with_prefix
* update
* TODO 1.
* todo: 2.
* todo: 3.
* update
* update
* Apply suggestions from code review
  Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
* reraise.
* move argument.

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
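At a glance, this change serializes each adapter's `LoraConfig` into the safetensors header under the `lora_adapter_metadata` key at save time, and reads it back at load time instead of re-deriving rank/alpha from the weights. A minimal sketch of the model-level round trip, based on the tests added below (the tiny UNet config and the output directory are illustrative placeholders):

    import os
    from peft import LoraConfig
    from diffusers import UNet2DConditionModel

    # Tiny UNet purely for illustration; any diffusers model that inherits PeftAdapterMixin
    # behaves the same way.
    unet = UNet2DConditionModel(
        block_out_channels=(32, 64),
        layers_per_block=2,
        sample_size=32,
        in_channels=4,
        out_channels=4,
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
    )

    unet.add_adapter(
        LoraConfig(r=4, lora_alpha=8, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False)
    )

    # save_lora_adapter() now also writes peft_config["default"].to_dict() into the
    # safetensors header under "lora_adapter_metadata".
    os.makedirs("lora-out", exist_ok=True)
    unet.save_lora_adapter("lora-out")

    # load_lora_adapter() rebuilds the LoraConfig from that header rather than guessing
    # rank/alpha from the state dict; the reloaded adapter gets the name "default_0".
    unet.unload_lora()
    unet.load_lora_adapter("lora-out", prefix=None)
    print(unet.peft_config["default_0"].to_dict())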
@@ -282,10 +282,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)

- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,

@@ -159,10 +159,7 @@ class IPAdapterMixin:
" `low_cpu_mem_usage=False`."
)

- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder

@@ -465,10 +462,7 @@ class FluxIPAdapterMixin:
" `low_cpu_mem_usage=False`."
)

- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder

@@ -750,10 +744,7 @@ class SD3IPAdapterMixin:
" `low_cpu_mem_usage=False`."
)

- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
@@ -14,6 +14,7 @@

import copy
import inspect
+ import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

@@ -45,6 +46,7 @@ from ..utils import (
set_adapter_layers,
set_weights_and_activate_adapters,
)
+ from ..utils.state_dict_utils import _load_sft_state_dict_metadata


if is_transformers_available():

@@ -62,6 +64,7 @@ logger = logging.get_logger(__name__)

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+ LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"


def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):

@@ -206,6 +209,7 @@ def _fetch_state_dict(
subfolder,
user_agent,
allow_pickle,
+ metadata=None,
):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):

@@ -236,11 +240,14 @@ def _fetch_state_dict(
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
+ metadata = _load_sft_state_dict_metadata(model_file)

except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
+ metadata = None
pass

if model_file is None:

@@ -261,10 +268,11 @@ def _fetch_state_dict(
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
+ metadata = None
else:
state_dict = pretrained_model_name_or_path_or_dict

- return state_dict
+ return state_dict, metadata


def _best_guess_weight_name(
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
return weight_name


+ def _pack_dict_with_prefix(state_dict, prefix):
+ sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+ return sd_with_prefix


def _load_lora_into_text_encoder(
state_dict,
network_alphas,
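A quick illustration of the new helper, assuming it stays importable as a private function from `diffusers.loaders.lora_base` (the tensor value is a dummy):

    import torch
    from diffusers.loaders.lora_base import _pack_dict_with_prefix

    # The helper just namespaces every key of a module's LoRA state dict.
    packed = _pack_dict_with_prefix({"lora_A.weight": torch.zeros(2, 2)}, "text_encoder")
    print(list(packed))  # ['text_encoder.lora_A.weight']

Further down in this file, `LoraBaseMixin.pack_weights` is rewritten to delegate to this helper.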
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
+ metadata=None,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

+ if network_alphas and metadata:
+ raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")

peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):

@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+ if metadata is not None:
+ metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

if len(state_dict) > 0:
logger.info(f"Loading {prefix}.")

@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

- lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+ if metadata is not None:
+ lora_config_kwargs = metadata
+ else:
+ lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:

@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

- lora_config = LoraConfig(**lora_config_kwargs)
+ try:
+ lora_config = LoraConfig(**lora_config_kwargs)
+ except TypeError as e:
+ raise TypeError("`LoraConfig` class could not be instantiated.") from e

# adapter_name
if adapter_name is None:
@@ -889,8 +914,7 @@ class LoraBaseMixin:
@staticmethod
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
- layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
- return layers_state_dict
+ return _pack_dict_with_prefix(layers_weights, prefix)

@staticmethod
def write_lora_layers(
@@ -900,16 +924,32 @@ class LoraBaseMixin:
weight_name: str,
save_function: Callable,
safe_serialization: bool,
+ lora_adapter_metadata: Optional[dict] = None,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return

+ if lora_adapter_metadata and not safe_serialization:
+ raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+ if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+ raise TypeError("`lora_adapter_metadata` must be of type `dict`.")

if save_function is None:
if safe_serialization:

def save_function(weights, filename):
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+ # Inject framework format.
+ metadata = {"format": "pt"}
+ if lora_adapter_metadata:
+ for key, value in lora_adapter_metadata.items():
+ if isinstance(value, set):
+ lora_adapter_metadata[key] = list(value)
+ metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+ lora_adapter_metadata, indent=2, sort_keys=True
+ )

+ return safetensors.torch.save_file(weights, filename, metadata=metadata)

else:
save_function = torch.save
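Downstream, the pipeline-level `save_lora_weights` methods feed per-module metadata into `write_lora_layers` through `<module>_lora_adapter_metadata` keyword arguments, and `lora_state_dict(..., return_lora_metadata=True)` surfaces it again; the `lora_pipeline.py` diff carrying those signatures is collapsed below, so the exact kwarg names here are inferred from the tests added later in this commit. A rough sketch under those assumptions, with a placeholder checkpoint id and output directory:

    from peft import LoraConfig
    from peft.utils import get_peft_model_state_dict
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")  # placeholder checkpoint

    lora_config = LoraConfig(r=4, lora_alpha=8, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False)
    pipe.unet.add_adapter(lora_config)

    # The kwarg name follows the "<module>_lora_adapter_metadata" pattern exercised by the tests below;
    # here only the UNet carries an adapter.
    StableDiffusionPipeline.save_lora_weights(
        save_directory="lora-out",
        unet_lora_layers=get_peft_model_state_dict(pipe.unet),
        unet_lora_adapter_metadata=pipe.unet.peft_config["default"].to_dict(),
    )

    # return_lora_metadata=True appends the parsed header metadata to the usual return values
    # (the state dict, plus network_alphas for pipelines that expose them), so take the last element.
    out = StableDiffusionPipeline.lora_state_dict("lora-out", return_lora_metadata=True)
    parsed_metadata = out[-1]
    print(sorted(parsed_metadata))  # keys are prefixed per module, e.g. "unet.r", "unet.lora_alpha"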
File diff suppressed because it is too large
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
+ import json
import os
from functools import partial
from pathlib import Path

@@ -185,6 +186,7 @@ class PeftAdapterMixin:
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+ metadata: TODO
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer

@@ -202,6 +204,7 @@ class PeftAdapterMixin:
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+ metadata = kwargs.pop("metadata", None)
allow_pickle = False

if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):

@@ -209,12 +212,9 @@ class PeftAdapterMixin:
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

- state_dict = _fetch_state_dict(
+ state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -227,12 +227,17 @@ class PeftAdapterMixin:
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
+ metadata=metadata,
)
if network_alphas is not None and prefix is None:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
+ if network_alphas and metadata:
+ raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")

if prefix is not None:
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+ if metadata is not None:
+ metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:

@@ -267,7 +272,12 @@ class PeftAdapterMixin:
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}

- lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+ if metadata is not None:
+ lora_config_kwargs = metadata
+ else:
+ lora_config_kwargs = get_peft_kwargs(
+ rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
+ )
_maybe_raise_error_for_ambiguity(lora_config_kwargs)

if "use_dora" in lora_config_kwargs:

@@ -290,7 +300,11 @@ class PeftAdapterMixin:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")

- lora_config = LoraConfig(**lora_config_kwargs)
+ try:
+ lora_config = LoraConfig(**lora_config_kwargs)
+ except TypeError as e:
+ raise TypeError("`LoraConfig` class could not be instantiated.") from e

# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
@@ -445,17 +459,13 @@ class PeftAdapterMixin:
underlying model has multiple adapters loaded.
upcast_before_saving (`bool`, defaults to `False`):
Whether to cast the underlying model to `torch.float32` before serialization.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
"""
from peft.utils import get_peft_model_state_dict

- from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+ from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE

if adapter_name is None:
adapter_name = get_adapter_name(self)

@@ -463,6 +473,8 @@ class PeftAdapterMixin:
if adapter_name not in getattr(self, "peft_config", {}):
raise ValueError(f"Adapter name {adapter_name} not found in the model.")

+ lora_adapter_metadata = self.peft_config[adapter_name].to_dict()

lora_layers_to_save = get_peft_model_state_dict(
self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
)

@@ -472,7 +484,15 @@ class PeftAdapterMixin:
if safe_serialization:

def save_function(weights, filename):
- return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+ # Inject framework format.
+ metadata = {"format": "pt"}
+ if lora_adapter_metadata is not None:
+ for key, value in lora_adapter_metadata.items():
+ if isinstance(value, set):
+ lora_adapter_metadata[key] = list(value)
+ metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)

+ return safetensors.torch.save_file(weights, filename, metadata=metadata)

else:
save_function = torch.save

@@ -485,7 +505,6 @@ class PeftAdapterMixin:
else:
weight_name = LORA_WEIGHT_NAME

- # TODO: we could consider saving the `peft_config` as well.
save_path = Path(save_directory, weight_name).as_posix()
save_function(lora_layers_to_save, save_path)
logger.info(f"Model weights saved in {save_path}")
@@ -155,10 +155,7 @@ class UNet2DConditionLoadersMixin:
use_safetensors = True
allow_pickle = True

- user_agent = {
- "file_type": "attn_procs_weights",
- "framework": "pytorch",
- }
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -16,6 +16,7 @@ State dict utilities: utility methods for converting state dicts easily
"""

import enum
+ import json

from .import_utils import is_torch_available
from .logging import get_logger
@@ -347,3 +348,16 @@ def state_dict_all_zero(state_dict, filter_str=None):
state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}

return all(torch.all(param == 0).item() for param in state_dict.values())


+ def _load_sft_state_dict_metadata(model_file: str):
+ import safetensors.torch

+ from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY

+ with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
+ metadata = f.metadata() or {}

+ metadata.pop("format", None)
+ raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+ return json.loads(raw) if raw else None
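For reference, the stored header can also be inspected by hand; a minimal sketch mirroring the helper above, assuming the file was written by `save_lora_adapter()` into a "lora-out" directory (the path is a placeholder):

    import json
    import safetensors.torch

    # The adapter config lives in the safetensors header as a JSON string under
    # the "lora_adapter_metadata" key, next to the usual {"format": "pt"} entry.
    with safetensors.torch.safe_open("lora-out/pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
        header = f.metadata() or {}

    config_dict = json.loads(header.get("lora_adapter_metadata", "{}"))
    print(config_dict.get("r"), config_dict.get("lora_alpha"), config_dict.get("target_modules"))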
@@ -133,6 +133,29 @@ def numpy_cosine_similarity_distance(a, b):
return distance


+ def check_if_dicts_are_equal(dict1, dict2):
+ dict1, dict2 = dict1.copy(), dict2.copy()

+ for key, value in dict1.items():
+ if isinstance(value, set):
+ dict1[key] = sorted(value)
+ for key, value in dict2.items():
+ if isinstance(value, set):
+ dict2[key] = sorted(value)

+ for key in dict1:
+ if key not in dict2:
+ return False
+ if dict1[key] != dict2[key]:
+ return False

+ for key in dict2:
+ if key not in dict1:
+ return False

+ return True


def print_tensor_test(
tensor,
limit_to_slices=None,
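Sets are normalized to sorted lists so that, for example, a `target_modules` set from a live `LoraConfig` can compare equal to the list it becomes after the JSON round trip. A tiny illustration with dummy values:

    from diffusers.utils.testing_utils import check_if_dicts_are_equal

    saved = {"r": 4, "target_modules": {"to_q", "to_k"}}      # as produced by LoraConfig.to_dict()
    reloaded = {"r": 4, "target_modules": ["to_k", "to_q"]}   # as parsed back from the JSON header
    print(check_if_dicts_are_equal(saved, reloaded))  # True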
@@ -24,11 +24,7 @@ from diffusers import (
WanPipeline,
WanTransformer3DModel,
)
- from diffusers.utils.testing_utils import (
- floats_tensor,
- require_peft_backend,
- skip_mps,
- )
+ from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps


sys.path.append(".")
@@ -22,6 +22,7 @@ from itertools import product
import numpy as np
import pytest
import torch
+ from parameterized import parameterized

from diffusers import (
AutoencoderKL,

@@ -33,6 +34,7 @@ from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
CaptureLogger,
+ check_if_dicts_are_equal,
floats_tensor,
is_torch_version,
require_peft_backend,

@@ -71,6 +73,13 @@ def check_if_lora_correctly_set(model) -> bool:
return False


+ def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str):
+ extracted = {
+ k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.")
+ }
+ check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"])


def initialize_dummy_state_dict(state_dict):
if not all(v.device.type == "meta" for _, v in state_dict.items()):
raise ValueError("`state_dict` has non-meta values.")
@@ -118,7 +127,7 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

- def get_dummy_components(self, scheduler_cls=None, use_dora=False):
+ def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:

@@ -126,6 +135,7 @@ class PeftLoraLoaderMixinTests:

scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
rank = 4
+ lora_alpha = rank if lora_alpha is None else lora_alpha

torch.manual_seed(0)
if self.unet_kwargs is not None:

@@ -161,7 +171,7 @@

text_lora_config = LoraConfig(
r=rank,
- lora_alpha=rank,
+ lora_alpha=lora_alpha,
target_modules=self.text_encoder_target_modules,
init_lora_weights=False,
use_dora=use_dora,

@@ -169,7 +179,7 @@

denoiser_lora_config = LoraConfig(
r=rank,
- lora_alpha=rank,
+ lora_alpha=lora_alpha,
target_modules=self.denoiser_target_modules,
init_lora_weights=False,
use_dora=use_dora,

@@ -246,6 +256,13 @@
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
return state_dicts

+ def _get_lora_adapter_metadata(self, modules_to_save):
+ metadatas = {}
+ for module_name, module in modules_to_save.items():
+ if module is not None:
+ metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
+ return metadatas

def _get_modules_to_save(self, pipe, has_denoiser=False):
modules_to_save = {}
lora_loadable_modules = self.pipeline_class._lora_loadable_modules
@@ -2214,6 +2231,86 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe(**inputs, generator=torch.manual_seed(0))[0]

+ @parameterized.expand([4, 8, 16])
+ def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
+ scheduler_cls = self.scheduler_classes[0]
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+ scheduler_cls, lora_alpha=lora_alpha
+ )
+ pipe = self.pipeline_class(**components)

+ pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )

+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+ pipe.unload_lora_weights()

+ out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
+ if len(out) == 3:
+ _, _, parsed_metadata = out
+ elif len(out) == 2:
+ _, parsed_metadata = out

+ denoiser_key = (
+ f"{self.pipeline_class.transformer_name}"
+ if self.transformer_kwargs is not None
+ else f"{self.pipeline_class.unet_name}"
+ )
+ self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
+ check_module_lora_metadata(
+ parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
+ )

+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ text_encoder_key = self.pipeline_class.text_encoder_name
+ self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
+ check_module_lora_metadata(
+ parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
+ )

+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ text_encoder_2_key = "text_encoder_2"
+ self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
+ check_module_lora_metadata(
+ parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
+ )

+ @parameterized.expand([4, 8, 16])
+ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
+ scheduler_cls = self.scheduler_classes[0]
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+ scheduler_cls, lora_alpha=lora_alpha
+ )
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)

+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_lora.shape == self.output_shape)

+ pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(tmpdir)

+ output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]

+ self.assertTrue(
+ np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
+ )

def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
@@ -30,6 +30,7 @@ from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import requests_mock
+ import safetensors.torch
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size

@@ -62,6 +63,7 @@ from diffusers.utils.testing_utils import (
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
+ check_if_dicts_are_equal,
get_python_version,
is_torch_compile,
numpy_cosine_similarity_distance,

@@ -1057,11 +1059,10 @@ class ModelTesterMixin:
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)

- @parameterized.expand([True, False])
+ @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
- def test_lora_save_load_adapter(self, use_dora=False):
- import safetensors
+ def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
@@ -1077,8 +1078,8 @@
output_no_lora = model(**inputs_dict, return_dict=False)[0]

denoiser_lora_config = LoraConfig(
- r=4,
- lora_alpha=4,
+ r=rank,
+ lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
@@ -1145,6 +1146,90 @@ class ModelTesterMixin:

self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))

+ @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
+ @torch.no_grad()
+ @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+ def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
+ from peft import LoraConfig

+ from diffusers.loaders.peft import PeftAdapterMixin

+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)

+ if not issubclass(model.__class__, PeftAdapterMixin):
+ return

+ denoiser_lora_config = LoraConfig(
+ r=rank,
+ lora_alpha=lora_alpha,
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+ init_lora_weights=False,
+ use_dora=use_dora,
+ )
+ model.add_adapter(denoiser_lora_config)
+ metadata = model.peft_config["default"].to_dict()
+ self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")

+ with tempfile.TemporaryDirectory() as tmpdir:
+ model.save_lora_adapter(tmpdir)
+ model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+ self.assertTrue(os.path.isfile(model_file))

+ model.unload_lora()
+ self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")

+ model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+ parsed_metadata = model.peft_config["default_0"].to_dict()
+ check_if_dicts_are_equal(metadata, parsed_metadata)

+ @torch.no_grad()
+ @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+ def test_lora_adapter_wrong_metadata_raises_error(self):
+ from peft import LoraConfig

+ from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+ from diffusers.loaders.peft import PeftAdapterMixin

+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)

+ if not issubclass(model.__class__, PeftAdapterMixin):
+ return

+ denoiser_lora_config = LoraConfig(
+ r=4,
+ lora_alpha=4,
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+ init_lora_weights=False,
+ use_dora=False,
+ )
+ model.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")

+ with tempfile.TemporaryDirectory() as tmpdir:
+ model.save_lora_adapter(tmpdir)
+ model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+ self.assertTrue(os.path.isfile(model_file))

+ # Perturb the metadata in the state dict.
+ loaded_state_dict = safetensors.torch.load_file(model_file)
+ metadata = {"format": "pt"}
+ lora_adapter_metadata = denoiser_lora_config.to_dict()
+ lora_adapter_metadata.update({"foo": 1, "bar": 2})
+ for key, value in lora_adapter_metadata.items():
+ if isinstance(value, set):
+ lora_adapter_metadata[key] = list(value)
+ metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+ safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)

+ model.unload_lora()
+ self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")

+ with self.assertRaises(TypeError) as err_context:
+ model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+ self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))

@require_torch_accelerator
def test_cpu_offload(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()