mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00

@@ -298,8 +298,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.load_attn_procs(
+            unet.load_lora_adapter(
                 state_dict,
+                prefix=cls.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,

@@ -827,8 +828,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         if not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-            unet.load_attn_procs(
+            unet.load_lora_adapter(
                 state_dict,
+                prefix=cls.unet_name,
                 network_alphas=network_alphas,
                 adapter_name=adapter_name,
                 _pipeline=_pipeline,
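
Both hunks above make the same change in the SD and SDXL pipeline loaders: the UNet half of the LoRA state dict is now routed through `unet.load_lora_adapter()` instead of the deprecated `unet.load_attn_procs()`, with an explicit `prefix=cls.unet_name` so only `unet.`-prefixed keys are consumed. The user-facing entry point is unchanged; a minimal sketch (the repo IDs below are placeholders, not taken from this diff):

```python
import torch

from diffusers import StableDiffusionPipeline

# Placeholder IDs; any SD 1.x checkpoint plus a diffusers-format LoRA works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Internally, the "unet."-prefixed keys are stripped and handed to
# unet.load_lora_adapter(); text-encoder keys take a separate path.
pipe.load_lora_weights("some-user/some-sd-lora", adapter_name="style")
```
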
@@ -13,9 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import os
 from functools import partial
+from pathlib import Path
 from typing import Dict, List, Optional, Union
 
+import safetensors
 import torch
+import torch.nn as nn
 
 from ..utils import (

@@ -189,40 +193,45 @@ class PeftAdapterMixin:
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        if network_alphas is not None and prefix is None:
+            raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
 
-        keys = list(state_dict.keys())
-        transformer_keys = [k for k in keys if k.startswith(prefix)]
-        if len(transformer_keys) > 0:
-            state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
+        if prefix is not None:
+            keys = list(state_dict.keys())
+            model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
+            if len(model_keys) > 0:
+                state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
 
+        if len(state_dict) > 0:
+            if adapter_name in getattr(self, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
+                )
+
-        if len(state_dict.keys()) > 0:
             # check with first key if is not in peft format
             first_key = next(iter(state_dict.keys()))
             if "lora_A" not in first_key:
                 state_dict = convert_unet_state_dict_to_peft(state_dict)
 
-            if adapter_name in getattr(self, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
-                )
-
             rank = {}
             for key, val in state_dict.items():
                 if "lora_B" in key:
                     rank[key] = val.shape[1]
 
             if network_alphas is not None and len(network_alphas) >= 1:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
+                alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
                 network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
 
             lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
             if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
-                    raise ValueError(
-                        "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                    )
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
                 else:
-                    lora_config_kwargs.pop("use_dora")
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
             lora_config = LoraConfig(**lora_config_kwargs)
 
             # adapter_name
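
Two behavioral changes sit in the hunk above. First, key filtering now only runs when `prefix` is not `None`, and the match is tightened from `startswith(prefix)` to `startswith(f"{prefix}.")`, so `prefix="unet"` no longer captures keys that merely begin with the substring `unet`. A toy illustration (key names invented):

```python
# Invented keys, for illustration only.
state_dict = {
    "unet.down_blocks.0.to_q.lora_A.weight": "unet tensor",
    "unet_extras.to_q.lora_A.weight": "must not match",
    "text_encoder.q_proj.lora_A.weight": "text-encoder tensor",
}
prefix = "unet"
model_keys = [k for k in state_dict if k.startswith(f"{prefix}.")]
if len(model_keys) > 0:
    # Strip the prefix and drop everything that belongs to other components.
    state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
print(state_dict)  # {'down_blocks.0.to_q.lora_A.weight': 'unet tensor'}
```

Second, `use_dora` is now popped from the LoRA config kwargs only when `peft` is older than 0.9.0, rather than unconditionally, so DoRA metadata survives on current `peft` versions.
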
@@ -276,6 +285,69 @@ class PeftAdapterMixin:
                 _pipeline.enable_sequential_cpu_offload()
             # Unsafe code />
 
+    def save_lora_adapter(
+        self,
+        save_directory,
+        adapter_name: str = "default",
+        upcast_before_saving: bool = False,
+        safe_serialization: bool = True,
+        weight_name: Optional[str] = None,
+    ):
+        """
+        Save the LoRA parameters corresponding to the underlying model.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            adapter_name: (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the
+                underlying model has multiple adapters loaded.
+            upcast_before_saving (`bool`, defaults to `False`):
+                Whether to cast the underlying model to `torch.float32` before serialization.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
+        """
+        from peft.utils import get_peft_model_state_dict
+
+        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+
+        if adapter_name is None:
+            adapter_name = get_adapter_name(self)
+
+        if adapter_name not in getattr(self, "peft_config", {}):
+            raise ValueError(f"Adapter name {adapter_name} not found in the model.")
+
+        lora_layers_to_save = get_peft_model_state_dict(
+            self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
+        )
+        if os.path.isfile(save_directory):
+            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+        if safe_serialization:
+
+            def save_function(weights, filename):
+                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+
+        else:
+            save_function = torch.save
+
+        os.makedirs(save_directory, exist_ok=True)
+
+        if weight_name is None:
+            if safe_serialization:
+                weight_name = LORA_WEIGHT_NAME_SAFE
+            else:
+                weight_name = LORA_WEIGHT_NAME
+
+        # TODO: we could consider saving the `peft_config` as well.
+        save_path = Path(save_directory, weight_name).as_posix()
+        save_function(lora_layers_to_save, save_path)
+        logger.info(f"Model weights saved in {save_path}")
+
     def set_adapters(
         self,
         adapter_names: Union[List[str], str],
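
A hedged usage sketch for the new `save_lora_adapter()` and its `load_lora_adapter()` counterpart. The model and LoRA config below are illustrative, not taken from this diff; `prefix=None` is passed on reload because `save_lora_adapter()` writes keys without a model prefix:

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Illustrative model; any PeftAdapterMixin subclass behaves the same way.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_v"]))

unet.save_lora_adapter("my-unet-lora")  # writes pytorch_lora_weights.safetensors
unet.load_lora_adapter("my-unet-lora", prefix=None, use_safetensors=True)
```
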
@@ -36,6 +36,7 @@ from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_unet_state_dict_to_peft,
+    deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,

@@ -209,6 +210,10 @@ class UNet2DConditionLoadersMixin:
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
+        if is_lora:
+            deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
+            deprecate("load_attn_procs", "0.40.0", deprecation_message)
+
         if is_custom_diffusion:
             attn_processors = self._process_custom_diffusion(state_dict=state_dict)
         elif is_lora:
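
`deprecate` is diffusers' internal helper; judging from the call site it takes the deprecated name, the removal version, and a user-facing message, and emits a `FutureWarning` (the rewritten test further down confirms the warning category). A simplified stand-in, not the actual implementation:

```python
import warnings


def deprecate(name: str, removal_version: str, message: str) -> None:
    # Simplified sketch: the real helper presumably also compares the installed
    # diffusers version against `removal_version`.
    warnings.warn(
        f"`{name}` is deprecated and will be removed in version {removal_version}. {message}",
        FutureWarning,
        stacklevel=2,
    )


deprecate("load_attn_procs", "0.40.0", "Please use `load_lora_adapter()`.")
```
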
@@ -1784,11 +1784,7 @@ class PeftLoraLoaderMixinTests:
         missing_key = [k for k in state_dict if "lora_A" in k][0]
         del state_dict[missing_key]
 
-        logger = (
-            logging.get_logger("diffusers.loaders.unet")
-            if self.unet_kwargs is not None
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)

@@ -1823,11 +1819,7 @@ class PeftLoraLoaderMixinTests:
         unexpected_key = [k for k in state_dict if "lora_A" in k][0] + ".diffusers_cat"
         state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)
 
-        logger = (
-            logging.get_logger("diffusers.loaders.unet")
-            if self.unet_kwargs is not None
-            else logging.get_logger("diffusers.loaders.peft")
-        )
+        logger = logging.get_logger("diffusers.loaders.peft")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
             pipe.load_lora_weights(state_dict)
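
Since the UNet LoRA path now lives in `diffusers.loaders.peft`, both tests capture that module's logger unconditionally and the branch on `self.unet_kwargs` is obsolete. Level 30 equals `logging.WARNING`. A minimal stand-in for the capture pattern (diffusers ships its own `CaptureLogger`; this sketch only illustrates the mechanics):

```python
import io
import logging


class CaptureLogger:
    """Records everything the wrapped logger emits while the context is open."""

    def __init__(self, logger):
        self.logger, self.io, self.out = logger, io.StringIO(), ""
        self.handler = logging.StreamHandler(self.io)

    def __enter__(self):
        self.logger.addHandler(self.handler)
        return self

    def __exit__(self, *exc):
        self.logger.removeHandler(self.handler)
        self.out = self.io.getvalue()


logger = logging.getLogger("diffusers.loaders.peft")
logger.setLevel(30)  # 30 == logging.WARNING
with CaptureLogger(logger) as cap_logger:
    logger.warning("example: unexpected LoRA keys")
assert "unexpected LoRA keys" in cap_logger.out
```
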
@@ -44,6 +44,7 @@ from diffusers.training_utils import EMAModel
 from diffusers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
     WEIGHTS_INDEX_NAME,
+    is_peft_available,
     is_torch_npu_available,
     is_xformers_available,
     logging,

@@ -65,6 +66,10 @@ from diffusers.utils.testing_utils import (
 from ..others.test_utils import TOKEN, USER, is_staging_test
 
 
+if is_peft_available():
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+
 def caculate_expected_num_shards(index_map_path):
     with open(index_map_path) as f:
         weight_map_dict = json.load(f)["weight_map"]

@@ -74,6 +79,16 @@ caculate_expected_num_shards(index_map_path):
     return expected_num_shards
 
 
+def check_if_lora_correctly_set(model) -> bool:
+    """
+    Checks if the LoRA layers are correctly set with peft
+    """
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return True
+    return False
+
+
 # Will be run via run_test_in_subprocess
 def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     error = None
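
`check_if_lora_correctly_set()` scans the module tree for any injected `peft` tuner layer. A quick usage sketch on a toy module (assumes `peft` is installed; the `nn.Sequential` stands in for a diffusers model):

```python
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer

base = nn.Sequential(nn.Linear(8, 8))
peft_model = get_peft_model(base, LoraConfig(r=4, target_modules=["0"]))

# The same scan the helper above performs.
assert any(isinstance(m, BaseTunerLayer) for m in peft_model.modules())
```
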
@@ -877,8 +892,6 @@ class ModelTesterMixin:
             model = model_class_copy(**init_dict)
             model.enable_gradient_checkpointing()
 
-            print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
-
             assert set(modules_with_gc_enabled.keys()) == expected_set
             assert all(modules_with_gc_enabled.values()), "All modules should be enabled"

@@ -902,6 +915,94 @@
             " from `_deprecated_kwargs = [<deprecated_argument>]`"
         )
 
+    @parameterized.expand([True, False])
+    @torch.no_grad()
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_save_load_lora_adapter(self, use_dora=False):
+        import safetensors
+        from peft import LoraConfig
+        from peft.utils import get_peft_model_state_dict
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        torch.manual_seed(0)
+        output_no_lora = model(**inputs_dict, return_dict=False)[0]
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+
+            model.unload_lora()
+            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
+
+            for k in state_dict_loaded:
+                loaded_v = state_dict_loaded[k]
+                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
+                self.assertTrue(torch.allclose(loaded_v, retrieved_v))
+
+            self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        torch.manual_seed(0)
+        outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+
+        self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+        self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
+
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_wrong_adapter_name_raises_error(self):
+        from peft import LoraConfig
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            wrong_name = "foo"
+            with self.assertRaises(ValueError) as err_context:
+                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+            self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
+
     @require_torch_gpu
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
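
`@parameterized.expand([True, False])` turns `test_save_load_lora_adapter` into two test cases, one per `use_dora` value, so the save/load round-trip is exercised for plain LoRA and DoRA alike; the reloaded weights are fetched under `default_0`, which appears to be the name `load_lora_adapter()` auto-generates when none is passed. A minimal stand-alone illustration of the decorator (the `parameterized` package is already a test dependency):

```python
import unittest

from parameterized import parameterized


class ExampleTest(unittest.TestCase):
    # Expands into two independent test cases, one per value in the list.
    @parameterized.expand([True, False])
    def test_flag(self, flag):
        self.assertIsInstance(flag, bool)


if __name__ == "__main__":
    unittest.main()
```
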
@@ -1078,30 +1078,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         assert new_output.sample.shape == (4, 4, 16, 16)
 
-    @require_peft_backend
-    def test_lora(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        # forward pass without LoRA
-        with torch.no_grad():
-            non_lora_sample = model(**inputs_dict).sample
-
-        unet_lora_config = get_unet_lora_config()
-        model.add_adapter(unet_lora_config)
-
-        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
-
-        # forward pass with LoRA
-        with torch.no_grad():
-            lora_sample = model(**inputs_dict).sample
-
-        assert not torch.allclose(
-            non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4
-        ), "LoRA injected UNet should produce different results."
-
     @require_peft_backend
-    def test_lora_serialization(self):
+    def test_load_attn_procs_raise_warning(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
         model.to(torch_device)

@@ -1122,8 +1099,14 @@
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_attn_procs(tmpdirname)
             model.unload_lora()
-            model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
 
+            with self.assertWarns(FutureWarning) as warning:
+                model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+            warning_message = str(warning.warnings[0].message)
+            assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
+
+            # important to still check for the rest of the stuff.
             assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
 
         with torch.no_grad():
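
The rewritten test relies on `unittest`'s `assertWarns` context manager, whose `warnings` attribute collects every warning recorded inside the block. A stand-alone illustration of the same pattern:

```python
import unittest
import warnings


class WarnExample(unittest.TestCase):
    def test_warns(self):
        with self.assertWarns(FutureWarning) as warning:
            warnings.warn(
                "Using the `load_attn_procs()` method has been deprecated", FutureWarning
            )
        self.assertIn("deprecated", str(warning.warnings[0].message))


if __name__ == "__main__":
    unittest.main()
```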