mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] feat: save_lora_adapter() (#9862)

* feat: save_lora_adapter.
Author: Sayak Paul
Date: 2024-11-19 12:33:38 +05:30
Committed by: GitHub
Parent: acf479bded
Commit: 7d0b9c4d4e
6 changed files with 210 additions and 55 deletions
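In short: PeftAdapterMixin gains a save_lora_adapter() counterpart to load_lora_adapter(), and UNet2DConditionLoadersMixin.load_attn_procs() is deprecated in its favor. A minimal round-trip sketch (the checkpoint id and output directory below are placeholders, not taken from this diff):

    import torch
    from peft import LoraConfig
    from diffusers import UNet2DConditionModel

    # Any model inheriting PeftAdapterMixin works; the UNet is just one example.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float32
    )

    # Attach a fresh adapter, then serialize only its weights.
    unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))
    unet.save_lora_adapter("my-lora")  # -> my-lora/pytorch_lora_weights.safetensors

    # Round-trip: drop the adapter and reload it from disk.
    unet.unload_lora()
    unet.load_lora_adapter("my-lora", prefix=None, use_safetensors=True)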

src/diffusers/loaders/lora_pipeline.py

@@ -298,8 +298,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.load_attn_procs(
+            unet.load_lora_adapter(
                 state_dict,
+                prefix=cls.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
@@ -827,8 +828,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.load_attn_procs(
+            unet.load_lora_adapter(
                 state_dict,
+                prefix=cls.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
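Both pipeline-level mixins now route UNet LoRA loading through load_lora_adapter() and pass prefix=cls.unet_name, because pipeline-level LoRA state dicts namespace their keys per component. A toy illustration of the filtering that prefix triggers (keys invented for the example):

    state_dict = {
        "unet.mid_block.attentions.0.to_q.lora_A.weight": 0,
        "text_encoder.layers.0.q_proj.lora_A.weight": 1,
    }
    prefix = "unet"
    kept = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
    assert list(kept) == ["mid_block.attentions.0.to_q.lora_A.weight"]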

src/diffusers/loaders/peft.py

@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import os
 from functools import partial
+from pathlib import Path
 from typing import Dict, List, Optional, Union

+import safetensors
 import torch
 import torch.nn as nn
from ..utils import (
@@ -189,40 +193,45 @@ class PeftAdapterMixin:
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        if network_alphas is not None and prefix is None:
+            raise ValueError("`network_alphas` cannot be None when `prefix` is None.")

-        keys = list(state_dict.keys())
-        transformer_keys = [k for k in keys if k.startswith(prefix)]
-        if len(transformer_keys) > 0:
-            state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
-
-        if len(state_dict.keys()) > 0:
+        if prefix is not None:
+            keys = list(state_dict.keys())
+            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
+            if len(model_keys) > 0:
+                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
+
+        if len(state_dict) > 0:
+            if adapter_name in getattr(self, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
+                )
+
             # check with first key if is not in peft format
             first_key = next(iter(state_dict.keys()))
             if "lora_A" not in first_key:
                 state_dict = convert_unet_state_dict_to_peft(state_dict)

-            if adapter_name in getattr(self, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
             rank = {}
             for key, val in state_dict.items():
                 if "lora_B" in key:
                     rank[key] = val.shape[1]

             if network_alphas is not None and len(network_alphas) >= 1:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}

             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
                 else:
-                    lora_config_kwargs.pop("use_dora")
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")

             lora_config = LoraConfig(**lora_config_kwargs)
             # adapter_name
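For context on the rank dictionary built above: a lora_B weight has shape (out_features, r), so val.shape[1] reads off the adapter rank that get_peft_kwargs() needs to build the LoraConfig. A toy check (shapes invented for the example):

    import torch

    out_features, in_features, r = 64, 32, 4
    state_dict = {
        "to_q.lora_A.weight": torch.zeros(r, in_features),   # (r, in)
        "to_q.lora_B.weight": torch.zeros(out_features, r),  # (out, r)
    }
    rank = {k: v.shape[1] for k, v in state_dict.items() if "lora_B" in k}
    assert rank == {"to_q.lora_B.weight": 4}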
@@ -276,6 +285,69 @@ class PeftAdapterMixin:
                 _pipeline.enable_sequential_cpu_offload()
         # Unsafe code />

+    def save_lora_adapter(
+        self,
+        save_directory,
+        adapter_name: str = "default",
+        upcast_before_saving: bool = False,
+        safe_serialization: bool = True,
+        weight_name: Optional[str] = None,
+    ):
+        """
+        Save the LoRA parameters corresponding to the underlying model.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            adapter_name: (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the
+                underlying model has multiple adapters loaded.
+            upcast_before_saving (`bool`, defaults to `False`):
+                Whether to cast the underlying model to `torch.float32` before serialization.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
+        """
+        from peft.utils import get_peft_model_state_dict
+
+        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+        if adapter_name is None:
+            adapter_name = get_adapter_name(self)
+
+        if adapter_name not in getattr(self, "peft_config", {}):
+            raise ValueError(f"Adapter name {adapter_name} not found in the model.")
+
+        lora_layers_to_save = get_peft_model_state_dict(
+            self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
+        )
+        if os.path.isfile(save_directory):
+            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        if safe_serialization:
+
+            def save_function(weights, filename):
+                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+        else:
+            save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = LORA_WEIGHT_NAME
+
+        # TODO: we could consider saving the `peft_config` as well.
+        save_path = Path(save_directory, weight_name).as_posix()
+        save_function(lora_layers_to_save, save_path)
+        logger.info(f"Model weights saved in {save_path}")
+
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
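The default file names come from lora_base: with safe_serialization=True (the default) the adapter is written to pytorch_lora_weights.safetensors, otherwise to pytorch_lora_weights.bin. A sketch of the resolution logic implemented above:

    LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
    LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"

    def resolve_weight_name(weight_name=None, safe_serialization=True):
        # An explicit weight_name always wins, mirroring save_lora_adapter().
        if weight_name is None:
            weight_name = LORA_WEIGHT_NAME_SAFE if safe_serialization else LORA_WEIGHT_NAME
        return weight_name

    assert resolve_weight_name() == "pytorch_lora_weights.safetensors"
    assert resolve_weight_name(safe_serialization=False) == "pytorch_lora_weights.bin"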

src/diffusers/loaders/unet.py

@@ -36,6 +36,7 @@ from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_unet_state_dict_to_peft,
+    deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
@@ -209,6 +210,10 @@ class UNet2DConditionLoadersMixin:
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False

+        if is_lora:
+            deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
+            deprecate("load_attn_procs", "0.40.0", deprecation_message)
+
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
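diffusers' deprecate() helper emits a FutureWarning, so the old entry point keeps working while flagging the migration. A sketch of catching it, assuming unet is a LoRA-equipped model as in the earlier sketch:

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        unet.load_attn_procs("my-lora/pytorch_lora_weights.safetensors")  # deprecated path

    assert any("load_attn_procs" in str(w.message) for w in caught)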

tests/lora/utils.py

@@ -1784,11 +1784,7 @@ class PeftLoraLoaderMixinTests:
         missing_key = [k for k in state_dict if "lora_A" in k][0]
         del state_dict[missing_key]

-        logger = (
-            logging.get_logger("diffusers.loaders.unet")
-            if self.unet_kwargs is not None
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)
@@ -1823,11 +1819,7 @@ class PeftLoraLoaderMixinTests:
         unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
         state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)

-        logger = (
-            logging.get_logger("diffusers.loaders.unet")
-            if self.unet_kwargs is not None
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)

tests/models/test_modeling_common.py

@@ -44,6 +44,7 @@ from diffusers.training_utils import EMAModel
 from diffusers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     WEIGHTS_INDEX_NAME,
+    is_peft_available,
     is_torch_npu_available,
     is_xformers_available,
     logging,
@@ -65,6 +66,10 @@ from diffusers.utils.testing_utils import (
 from ..others.test_utils import TOKEN, USER, is_staging_test

+if is_peft_available():
+    from peft.tuners.tuners_utils import BaseTunerLayer
+

 def caculate_expected_num_shards(index_map_path):
     with open(index_map_path) as f:
         weight_map_dict = json.load(f)["weight_map"]
@@ -74,6 +79,16 @@ def caculate_expected_num_shards(index_map_path):
     return expected_num_shards


+def check_if_lora_correctly_set(model) -> bool:
+    """
+    Checks if the LoRA layers are correctly set with peft
+    """
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return True
+    return False
+
+
 # Will be run via run_test_in_subprocess
 def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     error = None
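check_if_lora_correctly_set() works because peft swaps targeted modules for BaseTunerLayer subclasses when an adapter is injected, so scanning model.modules() for one is sufficient. A toy demonstration (the two-layer model is invented for the example):

    import torch.nn as nn
    from peft import LoraConfig, inject_adapter_in_model
    from peft.tuners.tuners_utils import BaseTunerLayer

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    # Target the Linear submodules by their names in the Sequential ("0" and "2").
    inject_adapter_in_model(LoraConfig(r=2, target_modules=["0", "2"]), model)

    assert any(isinstance(m, BaseTunerLayer) for m in model.modules())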
@@ -877,8 +892,6 @@ class ModelTesterMixin:
         model = model_class_copy(**init_dict)
         model.enable_gradient_checkpointing()

-        print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
-
         assert set(modules_with_gc_enabled.keys()) == expected_set
         assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
@@ -902,6 +915,94 @@ class ModelTesterMixin:
                 " from `_deprecated_kwargs = [<deprecated_argument>]`"
             )

+    @parameterized.expand([True, False])
+    @torch.no_grad()
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_save_load_lora_adapter(self, use_dora=False):
+        import safetensors
+        from peft import LoraConfig
+        from peft.utils import get_peft_model_state_dict
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        torch.manual_seed(0)
+        output_no_lora = model(**inputs_dict, return_dict=False)[0]
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+
+            model.unload_lora()
+            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
+
+            for k in state_dict_loaded:
+                loaded_v = state_dict_loaded[k]
+                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
+                self.assertTrue(torch.allclose(loaded_v, retrieved_v))
+
+            self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+        self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_wrong_adapter_name_raises_error(self):
+        from peft import LoraConfig
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            wrong_name = "foo"
+            with self.assertRaises(ValueError) as err_context:
+                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+            self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
+
     @require_torch_gpu
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()

tests/models/unets/test_models_unet_2d_condition.py

@@ -1078,30 +1078,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         assert new_output.sample.shape == (4, 4, 16, 16)

     @require_peft_backend
-    def test_lora(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        # forward pass without LoRA
-        with torch.no_grad():
-            non_lora_sample = model(**inputs_dict).sample
-
-        unet_lora_config = get_unet_lora_config()
-        model.add_adapter(unet_lora_config)
-
-        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
-
-        # forward pass with LoRA
-        with torch.no_grad():
-            lora_sample = model(**inputs_dict).sample
-
-        assert not torch.allclose(
-            non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4
-        ), "LoRA injected UNet should produce different results."
-
-    @require_peft_backend
-    def test_lora_serialization(self):
+    def test_load_attn_procs_raise_warning(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
         model.to(torch_device)
@@ -1122,8 +1099,14 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_attn_procs(tmpdirname)
             model.unload_lora()
-            model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

+            with self.assertWarns(FutureWarning) as warning:
+                model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            warning_message = str(warning.warnings[0].message)
+            assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
+
+            # important to still check for the rest of the stuff.
             assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."

         with torch.no_grad():