1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] Remove legacy LoRA code and related adjustments (#8316)

* remove legacy code from load_attn_procs.

* finish first draft

* fix more.

* fix more

* add test

* add serialization support.

* fix-copies

* require peft backend for lora tests

* style

* fix test

* fix loading.

* empty

* address benjamin's feedback.
This commit is contained in:
Sayak Paul
2024-06-05 08:15:30 +04:00
committed by GitHub
parent a8ad6664c2
commit a0542c1917
7 changed files with 399 additions and 432 deletions

View File

@@ -111,3 +111,21 @@ jobs:
-s -v \
--make-reports=tests_${{ matrix.config.report }} \
tests/lora/
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_models_lora_${{ matrix.config.report }} \
tests/models/ -k "lora"
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_${{ matrix.config.report }}_failures_short.txt
cat reports/tests_models_lora_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -189,12 +189,17 @@ jobs:
-s -v -k "not Flax and not Onnx and not PEFTLoRALoading" \
--make-reports=tests_peft_cuda \
tests/lora/
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "lora and not Flax and not Onnx and not PEFTLoRALoading" \
--make-reports=tests_peft_cuda_models_lora \
tests/models/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_peft_cuda_stats.txt
cat reports/tests_peft_cuda_failures_short.txt
cat reports/tests_peft_cuda_models_lora_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}

View File

@@ -22,17 +22,14 @@ import torch
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version
from torch import nn
from .. import __version__
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..models.modeling_utils import load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
@@ -119,13 +116,10 @@ class LoraLoaderMixin:
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
_pipeline=self,
)
@@ -136,7 +130,6 @@ class LoraLoaderMixin:
if not hasattr(self, "text_encoder")
else self.text_encoder,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
_pipeline=self,
)
@@ -193,16 +186,8 @@ class LoraLoaderMixin:
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
@@ -383,9 +368,7 @@ class LoraLoaderMixin:
return (is_model_cpu_offload, is_sequential_cpu_offload)
@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -395,14 +378,11 @@ class LoraLoaderMixin:
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alphas (`Dict[str, float]`):
See `LoRALinearLayer` for more details.
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
@@ -410,94 +390,18 @@ class LoraLoaderMixin:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
network_alphas = {
k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
else:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
if not USE_PEFT_BACKEND:
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
logger.warning(warn_message)
if len(state_dict.keys()) > 0:
if adapter_name in getattr(unet, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
)
state_dict = convert_unet_state_dict_to_peft(state_dict)
if network_alphas is not None:
# The alphas state dict have the same structure as Unet, thus we convert it to peft format using
# `convert_unet_state_dict_to_peft` method.
network_alphas = convert_unet_state_dict_to_peft(network_alphas)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(unet)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
)
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
)
@classmethod
def load_lora_into_text_encoder(
@@ -507,7 +411,6 @@ class LoraLoaderMixin:
text_encoder,
prefix=None,
lora_scale=1.0,
low_cpu_mem_usage=None,
adapter_name=None,
_pipeline=None,
):
@@ -527,11 +430,6 @@ class LoraLoaderMixin:
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
@@ -541,8 +439,6 @@ class LoraLoaderMixin:
from peft import LoraConfig
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
@@ -625,9 +521,7 @@ class LoraLoaderMixin:
# Unsafe code />
@classmethod
def load_lora_into_transformer(
cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -640,19 +534,12 @@ class LoraLoaderMixin:
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
@@ -846,22 +733,11 @@ class LoraLoaderMixin:
>>> ...
```
"""
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
if not USE_PEFT_BACKEND:
if version.parse(__version__) > version.parse("0.23"):
logger.warning(
"You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
"you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
)
raise ValueError("PEFT backend is required for this method.")
for _, module in unet.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
else:
recurse_remove_peft_layers(unet)
if hasattr(unet, "peft_config"):
del unet.peft_config
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet.unload_lora()
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()

View File

@@ -33,34 +33,32 @@ from ..models.embeddings import (
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_version,
is_torch_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .unet_loader_utils import _maybe_expand_lora_scales
from .utils import AttnProcsLayers
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
@@ -79,7 +77,8 @@ class UNet2DConditionLoadersMixin:
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
defined in
[`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
and be a `torch.nn.Module` class.
and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install
`peft`: `pip install -U peft`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -110,20 +109,20 @@ class UNet2DConditionLoadersMixin:
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if youre downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
adapter_name (`str`, *optional*, defaults to None):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
Example:
@@ -139,9 +138,6 @@ class UNet2DConditionLoadersMixin:
)
```
"""
from ..models.attention_processor import CustomDiffusionAttnProcessor
from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", None)
@@ -152,15 +148,9 @@ class UNet2DConditionLoadersMixin:
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alphas = kwargs.pop("network_alphas", None)
adapter_name = kwargs.pop("adapter_name", None)
_pipeline = kwargs.pop("_pipeline", None)
is_network_alphas_none = network_alphas is None
network_alphas = kwargs.pop("network_alphas", None)
allow_pickle = False
if use_safetensors is None:
@@ -216,198 +206,196 @@ class UNet2DConditionLoadersMixin:
else:
state_dict = pretrained_model_name_or_path_or_dict
# fill attn processors
lora_layers_list = []
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if is_lora:
# correct keys
state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
if network_alphas is not None:
network_alphas_keys = list(network_alphas.keys())
used_network_alphas_keys = set()
lora_grouped_dict = defaultdict(dict)
mapped_network_alphas = {}
all_keys = list(state_dict.keys())
for key in all_keys:
value = state_dict.pop(key)
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value
# Create another `mapped_network_alphas` dictionary so that we can properly map them.
if network_alphas is not None:
for k in network_alphas_keys:
if k.replace(".alpha", "") in key:
mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
used_network_alphas_keys.add(k)
if not is_network_alphas_none:
if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
raise ValueError(
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
)
if len(state_dict) > 0:
raise ValueError(
f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
)
for key, value_dict in lora_grouped_dict.items():
attn_processor = self
for sub_key in key.split("."):
attn_processor = getattr(attn_processor, sub_key)
# Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
# or add_{k,v,q,out_proj}_proj_lora layers.
rank = value_dict["lora.down.weight"].shape[0]
if isinstance(attn_processor, LoRACompatibleConv):
in_features = attn_processor.in_channels
out_features = attn_processor.out_channels
kernel_size = attn_processor.kernel_size
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
network_alpha=mapped_network_alphas.get(key),
)
elif isinstance(attn_processor, LoRACompatibleLinear):
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
lora = LoRALinearLayer(
attn_processor.in_features,
attn_processor.out_features,
rank,
mapped_network_alphas.get(key),
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
lora_layers_list.append((attn_processor, lora))
if low_cpu_mem_usage:
device = next(iter(value_dict.values())).device
dtype = next(iter(value_dict.values())).dtype
load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
else:
lora.load_state_dict(value_dict)
elif is_custom_diffusion:
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
if len(value) == 0:
custom_diffusion_grouped_dict[key] = {}
else:
if "to_out" in key:
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
else:
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
for key, value_dict in custom_diffusion_grouped_dict.items():
if len(value_dict) == 0:
attn_processors[key] = CustomDiffusionAttnProcessor(
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
)
else:
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
attn_processors[key] = CustomDiffusionAttnProcessor(
train_kv=True,
train_q_out=train_q_out,
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
attn_processors[key].load_state_dict(value_dict)
elif USE_PEFT_BACKEND:
# In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
# on the Unet
pass
if is_custom_diffusion:
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
elif is_lora:
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
state_dict=state_dict,
unet_identifier_key=self.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
)
else:
raise ValueError(
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training."
)
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to
# Now we remove any existing hooks to `_pipeline`.
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
if is_custom_diffusion and _pipeline is not None:
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
# only custom diffusion needs to set attn processors
self.set_attn_processor(attn_processors)
self.to(dtype=self.dtype, device=self.device)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def _process_custom_diffusion(self, state_dict):
from ..models.attention_processor import CustomDiffusionAttnProcessor
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
if len(value) == 0:
custom_diffusion_grouped_dict[key] = {}
else:
if "to_out" in key:
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
else:
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
for key, value_dict in custom_diffusion_grouped_dict.items():
if len(value_dict) == 0:
attn_processors[key] = CustomDiffusionAttnProcessor(
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
)
else:
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
attn_processors[key] = CustomDiffusionAttnProcessor(
train_kv=True,
train_q_out=train_q_out,
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
attn_processors[key].load_state_dict(value_dict)
return attn_processors
def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
# This method does the following things:
# 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
# format. For legacy format no filtering is applied.
# 2. Converts the `state_dict` to the `peft` compatible format.
# 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the
# `LoraConfig` specs.
# 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it.
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
unet_keys = [k for k in keys if k.startswith(unet_identifier_key)]
unet_state_dict = {
k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)]
network_alphas = {
k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
is_model_cpu_offload = False
is_sequential_cpu_offload = False
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
if len(state_dict_to_be_used) > 0:
if adapter_name in getattr(self, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
)
state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)
if network_alphas is not None:
# The alphas state dict have the same structure as Unet, thus we convert it to peft format using
# `convert_unet_state_dict_to_peft` method.
network_alphas = convert_unet_state_dict_to_peft(network_alphas)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
return is_model_cpu_offload, is_sequential_cpu_offload
@classmethod
# Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
if not USE_PEFT_BACKEND:
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# only custom diffusion needs to set attn processors
if is_custom_diffusion:
self.set_attn_processor(attn_processors)
# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)
self.to(dtype=self.dtype, device=self.device)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
is_new_lora_format = all(
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
)
if is_new_lora_format:
# Strip the `"unet"` prefix.
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
if is_text_encoder_present:
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
logger.warning(warn_message)
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
# change processor format to 'pure' LoRACompatibleLinear format
if any("processor" in k.split(".") for k in state_dict.keys()):
def format_to_lora_compatible(key):
if "processor" not in key.split("."):
return key
return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
if network_alphas is not None:
network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
return state_dict, network_alphas
return (is_model_cpu_offload, is_sequential_cpu_offload)
def save_attn_procs(
self,
@@ -460,6 +448,23 @@ class UNet2DConditionLoadersMixin:
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
is_custom_diffusion = any(
isinstance(
x,
(CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
)
for (_, x) in self.attn_processors.items()
)
if is_custom_diffusion:
state_dict = self._get_custom_diffusion_state_dict()
else:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
from peft.utils import get_peft_model_state_dict
state_dict = get_peft_model_state_dict(self)
if save_function is None:
if safe_serialization:
@@ -471,36 +476,6 @@ class UNet2DConditionLoadersMixin:
os.makedirs(save_directory, exist_ok=True)
is_custom_diffusion = any(
isinstance(
x,
(CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
)
for (_, x) in self.attn_processors.items()
)
if is_custom_diffusion:
model_to_save = AttnProcsLayers(
{
y: x
for (y, x) in self.attn_processors.items()
if isinstance(
x,
(
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
),
)
}
)
state_dict = model_to_save.state_dict()
for name, attn in self.attn_processors.items():
if len(attn.state_dict()) == 0:
state_dict[name] = {}
else:
model_to_save = AttnProcsLayers(self.attn_processors)
state_dict = model_to_save.state_dict()
if weight_name is None:
if safe_serialization:
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
@@ -512,56 +487,84 @@ class UNet2DConditionLoadersMixin:
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
def _get_custom_diffusion_state_dict(self):
from ..models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
)
model_to_save = AttnProcsLayers(
{
y: x
for (y, x) in self.attn_processors.items()
if isinstance(
x,
(
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
),
)
}
)
state_dict = model_to_save.state_dict()
for name, attn in self.attn_processors.items():
if len(attn.state_dict()) == 0:
state_dict[name] = {}
return state_dict
def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `fuse_lora()`.")
self.lora_scale = lora_scale
self._safe_fusing = safe_fusing
self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
def _fuse_lora_apply(self, module, adapter_names=None):
if not USE_PEFT_BACKEND:
if hasattr(module, "_fuse_lora"):
module._fuse_lora(self.lora_scale, self._safe_fusing)
from peft.tuners.tuners_utils import BaseTunerLayer
if adapter_names is not None:
merge_kwargs = {"safe_merge": self._safe_fusing}
if isinstance(module, BaseTunerLayer):
if self.lora_scale != 1.0:
module.scale_layer(self.lora_scale)
# For BC with prevous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
merge_kwargs["adapter_names"] = adapter_names
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported in your environment. Please switch"
" to PEFT backend to use this argument by installing latest PEFT and transformers."
" `pip install -U peft transformers`"
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
" to the latest version of PEFT. `pip install -U peft`"
)
else:
from peft.tuners.tuners_utils import BaseTunerLayer
merge_kwargs = {"safe_merge": self._safe_fusing}
if isinstance(module, BaseTunerLayer):
if self.lora_scale != 1.0:
module.scale_layer(self.lora_scale)
# For BC with prevous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
merge_kwargs["adapter_names"] = adapter_names
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
" to the latest version of PEFT. `pip install -U peft`"
)
module.merge(**merge_kwargs)
module.merge(**merge_kwargs)
def unfuse_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `unfuse_lora()`.")
self.apply(self._unfuse_lora_apply)
def _unfuse_lora_apply(self, module):
if not USE_PEFT_BACKEND:
if hasattr(module, "_unfuse_lora"):
module._unfuse_lora()
else:
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.tuners.tuners_utils import BaseTunerLayer
if isinstance(module, BaseTunerLayer):
module.unmerge()
if isinstance(module, BaseTunerLayer):
module.unmerge()
def unload_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `unload_lora()`.")
from ..utils import recurse_remove_peft_layers
recurse_remove_peft_layers(self)
if hasattr(self, "peft_config"):
del self.peft_config
def set_adapters(
self,

View File

@@ -903,17 +903,6 @@ class UNet2DConditionModel(
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def unload_lora(self):
"""Unloads LoRA weights."""
deprecate(
"unload_lora",
"0.28.0",
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
)
for module in self.modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
def get_time_embed(
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
) -> Optional[torch.Tensor]:

View File

@@ -22,7 +22,7 @@ import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, logging
from ...utils import BaseOutput, logging
from ..activations import get_activation
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -546,18 +546,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
def unload_lora(self):
"""Unloads LoRA weights."""
deprecate(
"unload_lora",
"0.28.0",
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
)
for module in self.modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
def forward(
self,
sample: torch.Tensor,

View File

@@ -37,7 +37,9 @@ from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
is_peft_available,
load_hf_numpy,
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_accelerator_with_training,
@@ -51,11 +53,38 @@ from diffusers.utils.testing_utils import (
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
if is_peft_available():
from peft import LoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer
logger = logging.get_logger(__name__)
enable_full_determinism()
def get_unet_lora_config():
rank = 4
unet_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
return unet_lora_config
def check_if_lora_correctly_set(model) -> bool:
"""
Checks if the LoRA layers are correctly set with peft
"""
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
def create_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
@@ -1005,6 +1034,65 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@require_peft_backend
def test_lora(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample = model(**inputs_dict).sample
assert not torch.allclose(
non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4
), "LoRA injected UNet should produce different results."
@require_peft_backend
def test_lora_serialization(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample_1 = model(**inputs_dict).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose(
non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
), "LoRA injected UNet should produce different results."
assert torch.allclose(
lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
), "Loading from a saved checkpoint should produce identical results."
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):