1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] parse metadata from LoRA and save metadata (#11324)

* feat: parse metadata from lora state dicts.

* tests

* fix tests

* key renaming

* fix

* smol update

* smol updates

* load metadata.

* automatically save metadata in save_lora_adapter.

* propagate changes.

* changes

* add test to models too.

* tigher tests.

* updates

* fixes

* rename tests.

* sorted.

* Update src/diffusers/loaders/lora_base.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* review suggestions.

* removeprefix.

* propagate changes.

* fix-copies

* sd

* docs.

* fixes

* get review ready.

* one more test to catch error.

* change to a different approach.

* fix-copies.

* todo

* sd3

* update

* revert changes in get_peft_kwargs.

* update

* fixes

* fixes

* simplify _load_sft_state_dict_metadata

* update

* style fix

* uipdate

* update

* update

* empty commit

* _pack_dict_with_prefix

* update

* TODO 1.

* todo: 2.

* todo: 3.

* update

* update

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* reraise.

* move argument.

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
This commit is contained in:
Sayak Paul
2025-06-13 14:37:49 +05:30
committed by GitHub
parent e52ceae375
commit 368958df6f
11 changed files with 845 additions and 199 deletions

View File

@@ -282,10 +282,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,

View File

@@ -159,10 +159,7 @@ class IPAdapterMixin:
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -465,10 +462,7 @@ class FluxIPAdapterMixin:
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -750,10 +744,7 @@ class SD3IPAdapterMixin:
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(

View File

@@ -14,6 +14,7 @@
import copy
import inspect
import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@ from ..utils import (
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.state_dict_utils import _load_sft_state_dict_metadata
if is_transformers_available():
@@ -62,6 +64,7 @@ logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +209,7 @@ def _fetch_state_dict(
subfolder,
user_agent,
allow_pickle,
metadata=None,
):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +240,14 @@ def _fetch_state_dict(
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
metadata = _load_sft_state_dict_metadata(model_file)
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
metadata = None
pass
if model_file is None:
@@ -261,10 +268,11 @@ def _fetch_state_dict(
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
metadata = None
else:
state_dict = pretrained_model_name_or_path_or_dict
return state_dict
return state_dict, metadata
def _best_guess_weight_name(
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
return weight_name
def _pack_dict_with_prefix(state_dict, prefix):
sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
return sd_with_prefix
def _load_lora_into_text_encoder(
state_dict,
network_alphas,
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
if network_alphas and metadata:
raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if metadata is not None:
metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
logger.info(f"Loading {prefix}.")
@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# adapter_name
if adapter_name is None:
@@ -889,8 +914,7 @@ class LoraBaseMixin:
@staticmethod
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
return _pack_dict_with_prefix(layers_weights, prefix)
@staticmethod
def write_lora_layers(
@@ -900,16 +924,32 @@ class LoraBaseMixin:
weight_name: str,
save_function: Callable,
safe_serialization: bool,
lora_adapter_metadata: Optional[dict] = None,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if lora_adapter_metadata and not safe_serialization:
raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
# Inject framework format.
metadata = {"format": "pt"}
if lora_adapter_metadata:
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
lora_adapter_metadata, indent=2, sort_keys=True
)
return safetensors.torch.save_file(weights, filename, metadata=metadata)
else:
save_function = torch.save

File diff suppressed because it is too large Load Diff

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
from functools import partial
from pathlib import Path
@@ -185,6 +186,7 @@ class PeftAdapterMixin:
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
metadata: TODO
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -202,6 +204,7 @@ class PeftAdapterMixin:
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
metadata = kwargs.pop("metadata", None)
allow_pickle = False
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
@@ -209,12 +212,9 @@ class PeftAdapterMixin:
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
@@ -227,12 +227,17 @@ class PeftAdapterMixin:
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
metadata=metadata,
)
if network_alphas is not None and prefix is None:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
if network_alphas and metadata:
raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")
if prefix is not None:
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if metadata is not None:
metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -267,7 +272,12 @@ class PeftAdapterMixin:
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
if "use_dora" in lora_config_kwargs:
@@ -290,7 +300,11 @@ class PeftAdapterMixin:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
@@ -445,17 +459,13 @@ class PeftAdapterMixin:
underlying model has multiple adapters loaded.
upcast_before_saving (`bool`, defaults to `False`):
Whether to cast the underlying model to `torch.float32` before serialization.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
"""
from peft.utils import get_peft_model_state_dict
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
if adapter_name is None:
adapter_name = get_adapter_name(self)
@@ -463,6 +473,8 @@ class PeftAdapterMixin:
if adapter_name not in getattr(self, "peft_config", {}):
raise ValueError(f"Adapter name {adapter_name} not found in the model.")
lora_adapter_metadata = self.peft_config[adapter_name].to_dict()
lora_layers_to_save = get_peft_model_state_dict(
self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
)
@@ -472,7 +484,15 @@ class PeftAdapterMixin:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
# Inject framework format.
metadata = {"format": "pt"}
if lora_adapter_metadata is not None:
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
return safetensors.torch.save_file(weights, filename, metadata=metadata)
else:
save_function = torch.save
@@ -485,7 +505,6 @@ class PeftAdapterMixin:
else:
weight_name = LORA_WEIGHT_NAME
# TODO: we could consider saving the `peft_config` as well.
save_path = Path(save_directory, weight_name).as_posix()
save_function(lora_layers_to_save, save_path)
logger.info(f"Model weights saved in {save_path}")

View File

@@ -155,10 +155,7 @@ class UNet2DConditionLoadersMixin:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):

View File

@@ -16,6 +16,7 @@ State dict utilities: utility methods for converting state dicts easily
"""
import enum
import json
from .import_utils import is_torch_available
from .logging import get_logger
@@ -347,3 +348,16 @@ def state_dict_all_zero(state_dict, filter_str=None):
state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
return all(torch.all(param == 0).item() for param in state_dict.values())
def _load_sft_state_dict_metadata(model_file: str):
import safetensors.torch
from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
return json.loads(raw) if raw else None

View File

@@ -133,6 +133,29 @@ def numpy_cosine_similarity_distance(a, b):
return distance
def check_if_dicts_are_equal(dict1, dict2):
dict1, dict2 = dict1.copy(), dict2.copy()
for key, value in dict1.items():
if isinstance(value, set):
dict1[key] = sorted(value)
for key, value in dict2.items():
if isinstance(value, set):
dict2[key] = sorted(value)
for key in dict1:
if key not in dict2:
return False
if dict1[key] != dict2[key]:
return False
for key in dict2:
if key not in dict1:
return False
return True
def print_tensor_test(
tensor,
limit_to_slices=None,

View File

@@ -24,11 +24,7 @@ from diffusers import (
WanPipeline,
WanTransformer3DModel,
)
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
skip_mps,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
sys.path.append(".")

View File

@@ -22,6 +22,7 @@ from itertools import product
import numpy as np
import pytest
import torch
from parameterized import parameterized
from diffusers import (
AutoencoderKL,
@@ -33,6 +34,7 @@ from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
CaptureLogger,
check_if_dicts_are_equal,
floats_tensor,
is_torch_version,
require_peft_backend,
@@ -71,6 +73,13 @@ def check_if_lora_correctly_set(model) -> bool:
return False
def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str):
extracted = {
k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.")
}
check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"])
def initialize_dummy_state_dict(state_dict):
if not all(v.device.type == "meta" for _, v in state_dict.items()):
raise ValueError("`state_dict` has non-meta values.")
@@ -118,7 +127,7 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
def get_dummy_components(self, scheduler_cls=None, use_dora=False):
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:
@@ -126,6 +135,7 @@ class PeftLoraLoaderMixinTests:
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
rank = 4
lora_alpha = rank if lora_alpha is None else lora_alpha
torch.manual_seed(0)
if self.unet_kwargs is not None:
@@ -161,7 +171,7 @@ class PeftLoraLoaderMixinTests:
text_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
lora_alpha=lora_alpha,
target_modules=self.text_encoder_target_modules,
init_lora_weights=False,
use_dora=use_dora,
@@ -169,7 +179,7 @@ class PeftLoraLoaderMixinTests:
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
lora_alpha=lora_alpha,
target_modules=self.denoiser_target_modules,
init_lora_weights=False,
use_dora=use_dora,
@@ -246,6 +256,13 @@ class PeftLoraLoaderMixinTests:
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
return state_dicts
def _get_lora_adapter_metadata(self, modules_to_save):
metadatas = {}
for module_name, module in modules_to_save.items():
if module is not None:
metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
return metadatas
def _get_modules_to_save(self, pipe, has_denoiser=False):
modules_to_save = {}
lora_loadable_modules = self.pipeline_class._lora_loadable_modules
@@ -2214,6 +2231,86 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe(**inputs, generator=torch.manual_seed(0))[0]
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, lora_alpha=lora_alpha
)
pipe = self.pipeline_class(**components)
pipe, _ = self.check_if_adapters_added_correctly(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
pipe.unload_lora_weights()
out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
if len(out) == 3:
_, _, parsed_metadata = out
elif len(out) == 2:
_, parsed_metadata = out
denoiser_key = (
f"{self.pipeline_class.transformer_name}"
if self.transformer_kwargs is not None
else f"{self.pipeline_class.unet_name}"
)
self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
check_module_lora_metadata(
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
text_encoder_key = self.pipeline_class.text_encoder_name
self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
check_module_lora_metadata(
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
)
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_key = "text_encoder_2"
self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
check_module_lora_metadata(
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
)
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, lora_alpha=lora_alpha
)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.check_if_adapters_added_correctly(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
pipe.unload_lora_weights()
pipe.load_lora_weights(tmpdir)
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
)
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:

View File

@@ -30,6 +30,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import requests_mock
import safetensors.torch
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
@@ -62,6 +63,7 @@ from diffusers.utils.testing_utils import (
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
check_if_dicts_are_equal,
get_python_version,
is_torch_compile,
numpy_cosine_similarity_distance,
@@ -1057,11 +1059,10 @@ class ModelTesterMixin:
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)
@parameterized.expand([True, False])
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_save_load_adapter(self, use_dora=False):
import safetensors
def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
@@ -1077,8 +1078,8 @@ class ModelTesterMixin:
output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
@@ -1145,6 +1146,90 @@ class ModelTesterMixin:
self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
from peft import LoraConfig
from diffusers.loaders.peft import PeftAdapterMixin
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
return
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(model_file))
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_adapter_wrong_metadata_raises_error(self):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
from diffusers.loaders.peft import PeftAdapterMixin
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
return
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(model_file))
# Perturb the metadata in the state dict.
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with self.assertRaises(TypeError) as err_context:
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
@require_torch_accelerator
def test_cpu_offload(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()