[LoRA] Remove legacy LoRA code and related adjustments (#8316)
* remove legacy code from load_attn_procs.
* finish first draft
* fix more.
* fix more
* add test
* add serialization support.
* fix-copies
* require peft backend for lora tests
* style
* fix test
* fix loading.
* empty
* address benjamin's feedback.
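After this change, all LoRA entry points assume the PEFT backend. A minimal, hedged usage sketch of the surviving path (the repo IDs and prompt below are placeholders, not part of this commit):

import torch
from diffusers import DiffusionPipeline

# Requires `pip install -U peft`; the legacy non-PEFT code paths removed in this commit no longer exist.
pipe = DiffusionPipeline.from_pretrained("some/sd-checkpoint", torch_dtype=torch.float16)
pipe.load_lora_weights("some/lora-repo", weight_name="pytorch_lora_weights.safetensors")
image = pipe("a prompt").images[0]
pipe.unload_lora_weights()  # now raises a ValueError if the PEFT backend is unavailable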
.github/workflows/pr_test_peft_backend.yml (vendored, 18 changed lines)
@@ -111,3 +111,21 @@ jobs:
           -s -v \
           --make-reports=tests_${{ matrix.config.report }} \
           tests/lora/
+        python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+          -s -v \
+          --make-reports=tests_models_lora_${{ matrix.config.report }} \
+          tests/models/ -k "lora"
+
+    - name: Failure short reports
+      if: ${{ failure() }}
+      run: |
+        cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+        cat reports/tests_models_lora_${{ matrix.config.report }}_failures_short.txt
+
+    - name: Test suite reports artifacts
+      if: ${{ always() }}
+      uses: actions/upload-artifact@v2
+      with:
+        name: pr_${{ matrix.config.report }}_test_reports
+        path: reports
.github/workflows/push_tests.yml (vendored, 5 changed lines)
@@ -189,12 +189,17 @@ jobs:
           -s -v -k "not Flax and not Onnx and not PEFTLoRALoading" \
           --make-reports=tests_peft_cuda \
           tests/lora/
+        python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
+          -s -v -k "lora and not Flax and not Onnx and not PEFTLoRALoading" \
+          --make-reports=tests_peft_cuda_models_lora \
+          tests/models/
 
     - name: Failure short reports
       if: ${{ failure() }}
       run: |
         cat reports/tests_peft_cuda_stats.txt
         cat reports/tests_peft_cuda_failures_short.txt
+        cat reports/tests_peft_cuda_models_lora_failures_short.txt
 
     - name: Test suite reports artifacts
       if: ${{ always() }}

src/diffusers/loaders/lora.py
@@ -22,17 +22,14 @@ import torch
 
 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE
 from huggingface_hub.utils import validate_hf_hub_args
 from packaging import version
 from torch import nn
 
 from .. import __version__
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
+from ..models.modeling_utils import load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
     delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
@@ -119,13 +116,10 @@ class LoraLoaderMixin:
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-
         self.load_lora_into_unet(
             state_dict,
             network_alphas=network_alphas,
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
         )
@@ -136,7 +130,6 @@ class LoraLoaderMixin:
             if not hasattr(self, "text_encoder")
             else self.text_encoder,
             lora_scale=self.lora_scale,
-            low_cpu_mem_usage=low_cpu_mem_usage,
             adapter_name=adapter_name,
             _pipeline=self,
         )
@@ -193,16 +186,8 @@ class LoraLoaderMixin:
             allowed by Git.
         subfolder (`str`, *optional*, defaults to `""`):
             The subfolder location of a model file within a larger model repository on the Hub or locally.
-        low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-            Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
-            tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-            Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-            argument to `True` will raise an error.
         mirror (`str`, *optional*):
             Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
             guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
             information.
         weight_name (`str`, *optional*, defaults to None):
             Name of the serialized state dict file.
     """
     # Load the main state dict first which has the LoRA layers for either of
     # UNet and text encoder or both.
@@ -383,9 +368,7 @@ class LoraLoaderMixin:
         return (is_model_cpu_offload, is_sequential_cpu_offload)
 
     @classmethod
-    def load_lora_into_unet(
-        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
-    ):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -395,14 +378,11 @@ class LoraLoaderMixin:
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
             network_alphas (`Dict[str, float]`):
-                See `LoRALinearLayer` for more details.
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
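As background for the `network_alphas` documentation above, the alpha acts as a rank-relative scale on the learned update; a small sketch of the usual LoRA convention (illustrative numbers, not taken from this diff):

# Conventional LoRA scaling: the delta added to a base weight is scaled by alpha / rank,
# so an alpha below the rank damps the update and helps avoid underflow in low precision.
rank, network_alpha = 8, 4
scaling = network_alpha / rank  # 0.5
# effective_weight = base_weight + scaling * (lora_B @ lora_A)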
@@ -410,94 +390,18 @@ class LoraLoaderMixin:
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
 
-        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
-
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
         # their prefixes.
         keys = list(state_dict.keys())
+        only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
 
-        if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
+        if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
             # Load the layers corresponding to UNet.
             logger.info(f"Loading {cls.unet_name}.")
-
-            unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
-            state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
-
-            if network_alphas is not None:
-                alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
-                network_alphas = {
-                    k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
-                }
-
-        else:
-            # Otherwise, we're dealing with the old format. This means the `state_dict` should only
-            # contain the module names of the `unet` as its keys WITHOUT any prefix.
-            if not USE_PEFT_BACKEND:
-                warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-                logger.warning(warn_message)
-
-        if len(state_dict.keys()) > 0:
-            if adapter_name in getattr(unet, "peft_config", {}):
-                raise ValueError(
-                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
-                )
-
-            state_dict = convert_unet_state_dict_to_peft(state_dict)
-
-            if network_alphas is not None:
-                # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
-                # `convert_unet_state_dict_to_peft` method.
-                network_alphas = convert_unet_state_dict_to_peft(network_alphas)
-
-            rank = {}
-            for key, val in state_dict.items():
-                if "lora_B" in key:
-                    rank[key] = val.shape[1]
-
-            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
-            if "use_dora" in lora_config_kwargs:
-                if lora_config_kwargs["use_dora"]:
-                    if is_peft_version("<", "0.9.0"):
-                        raise ValueError(
-                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-                        )
-                else:
-                    if is_peft_version("<", "0.9.0"):
-                        lora_config_kwargs.pop("use_dora")
-            lora_config = LoraConfig(**lora_config_kwargs)
-
-            # adapter_name
-            if adapter_name is None:
-                adapter_name = get_adapter_name(unet)
-
-            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
-            # otherwise loading LoRA weights will lead to an error
-            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
-
-            inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
-            incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
-
-            if incompatible_keys is not None:
-                # check only for unexpected keys
-                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
-                if unexpected_keys:
-                    logger.warning(
-                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
-                        f" {unexpected_keys}. "
-                    )
-
-            # Offload back.
-            if is_model_cpu_offload:
-                _pipeline.enable_model_cpu_offload()
-            elif is_sequential_cpu_offload:
-                _pipeline.enable_sequential_cpu_offload()
-            # Unsafe code />
-
-        unet.load_attn_procs(
-            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
-        )
+            unet.load_attn_procs(
+                state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
+            )
 
     @classmethod
     def load_lora_into_text_encoder(
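The removed legacy branch's warning pointed at a one-line migration for prefix-less checkpoints; as a standalone hedged sketch (file names below are placeholders):

# Convert an old-format (prefix-less) UNet LoRA state dict to the "unet."-prefixed
# format introduced in https://github.com/huggingface/diffusers/pull/2918.
import safetensors.torch

old_state_dict = safetensors.torch.load_file("old_lora.safetensors")
new_state_dict = {f"unet.{module_name}": params for module_name, params in old_state_dict.items()}
safetensors.torch.save_file(new_state_dict, "pytorch_lora_weights.safetensors")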
@@ -507,7 +411,6 @@ class LoraLoaderMixin:
         text_encoder,
         prefix=None,
         lora_scale=1.0,
-        low_cpu_mem_usage=None,
         adapter_name=None,
         _pipeline=None,
     ):
@@ -527,11 +430,6 @@ class LoraLoaderMixin:
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
@@ -541,8 +439,6 @@ class LoraLoaderMixin:
 
         from peft import LoraConfig
 
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
@@ -625,9 +521,7 @@ class LoraLoaderMixin:
         # Unsafe code />
 
     @classmethod
-    def load_lora_into_transformer(
-        cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
-    ):
+    def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -640,19 +534,12 @@ class LoraLoaderMixin:
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
 
-        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
-
         keys = list(state_dict.keys())
 
         transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
@@ -846,22 +733,11 @@ class LoraLoaderMixin:
         >>> ...
         ```
        """
-        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-
         if not USE_PEFT_BACKEND:
-            if version.parse(__version__) > version.parse("0.23"):
-                logger.warning(
-                    "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
-                    "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
-                )
-
-            for _, module in unet.named_modules():
-                if hasattr(module, "set_lora_layer"):
-                    module.set_lora_layer(None)
-        else:
-            recurse_remove_peft_layers(unet)
-            if hasattr(unet, "peft_config"):
-                del unet.peft_config
+            raise ValueError("PEFT backend is required for this method.")
+
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.unload_lora()
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()

src/diffusers/loaders/unet.py
@@ -33,34 +33,32 @@ from ..models.embeddings import (
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
+from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
+    convert_unet_state_dict_to_peft,
     delete_adapter_layers,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
+    is_peft_version,
     is_torch_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .unet_loader_utils import _maybe_expand_lora_scales
 from .utils import AttnProcsLayers
 
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
 
 logger = logging.get_logger(__name__)
 
 
-TEXT_ENCODER_NAME = "text_encoder"
-UNET_NAME = "unet"
-
-LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
-LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
-
 CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
@@ -79,7 +77,8 @@ class UNet2DConditionLoadersMixin:
         Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
         defined in
         [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
-        and be a `torch.nn.Module` class.
+        and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install
+        `peft`: `pip install -U peft`.
 
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -110,20 +109,20 @@ class UNet2DConditionLoadersMixin:
             token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             mirror (`str`, *optional*):
                 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
                 information.
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
             adapter_name (`str`, *optional*, defaults to None):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                 `default_{i}` where i is the total number of adapters being loaded.
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.
 
         Example:
@@ -139,9 +138,6 @@ class UNet2DConditionLoadersMixin:
             )
         ```
         """
-        from ..models.attention_processor import CustomDiffusionAttnProcessor
-        from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
-
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", None)
@@ -152,15 +148,9 @@ class UNet2DConditionLoadersMixin:
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
-        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
-        network_alphas = kwargs.pop("network_alphas", None)
-
         adapter_name = kwargs.pop("adapter_name", None)
         _pipeline = kwargs.pop("_pipeline", None)
-
-        is_network_alphas_none = network_alphas is None
-
+        network_alphas = kwargs.pop("network_alphas", None)
         allow_pickle = False
 
         if use_safetensors is None:
@@ -216,198 +206,196 @@ class UNet2DConditionLoadersMixin:
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
-        # fill attn processors
-        lora_layers_list = []
-
-        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
+        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
-        if is_lora:
-            # correct keys
-            state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
-
-            if network_alphas is not None:
-                network_alphas_keys = list(network_alphas.keys())
-                used_network_alphas_keys = set()
-
-            lora_grouped_dict = defaultdict(dict)
-            mapped_network_alphas = {}
-
-            all_keys = list(state_dict.keys())
-            for key in all_keys:
-                value = state_dict.pop(key)
-                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-                lora_grouped_dict[attn_processor_key][sub_key] = value
-
-                # Create another `mapped_network_alphas` dictionary so that we can properly map them.
-                if network_alphas is not None:
-                    for k in network_alphas_keys:
-                        if k.replace(".alpha", "") in key:
-                            mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
-                            used_network_alphas_keys.add(k)
-
-            if not is_network_alphas_none:
-                if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
-                    raise ValueError(
-                        f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
-                    )
-
-            if len(state_dict) > 0:
-                raise ValueError(
-                    f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
-                )
-
-            for key, value_dict in lora_grouped_dict.items():
-                attn_processor = self
-                for sub_key in key.split("."):
-                    attn_processor = getattr(attn_processor, sub_key)
-
-                # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
-                # or add_{k,v,q,out_proj}_proj_lora layers.
-                rank = value_dict["lora.down.weight"].shape[0]
-
-                if isinstance(attn_processor, LoRACompatibleConv):
-                    in_features = attn_processor.in_channels
-                    out_features = attn_processor.out_channels
-                    kernel_size = attn_processor.kernel_size
-
-                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
-                    with ctx():
-                        lora = LoRAConv2dLayer(
-                            in_features=in_features,
-                            out_features=out_features,
-                            rank=rank,
-                            kernel_size=kernel_size,
-                            stride=attn_processor.stride,
-                            padding=attn_processor.padding,
-                            network_alpha=mapped_network_alphas.get(key),
-                        )
-                elif isinstance(attn_processor, LoRACompatibleLinear):
-                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
-                    with ctx():
-                        lora = LoRALinearLayer(
-                            attn_processor.in_features,
-                            attn_processor.out_features,
-                            rank,
-                            mapped_network_alphas.get(key),
-                        )
-                else:
-                    raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
-
-                value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
-                lora_layers_list.append((attn_processor, lora))
-
-                if low_cpu_mem_usage:
-                    device = next(iter(value_dict.values())).device
-                    dtype = next(iter(value_dict.values())).dtype
-                    load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
-                else:
-                    lora.load_state_dict(value_dict)
-
-        elif is_custom_diffusion:
-            attn_processors = {}
-            custom_diffusion_grouped_dict = defaultdict(dict)
-            for key, value in state_dict.items():
-                if len(value) == 0:
-                    custom_diffusion_grouped_dict[key] = {}
-                else:
-                    if "to_out" in key:
-                        attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-                    else:
-                        attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
-                    custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
-
-            for key, value_dict in custom_diffusion_grouped_dict.items():
-                if len(value_dict) == 0:
-                    attn_processors[key] = CustomDiffusionAttnProcessor(
-                        train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
-                    )
-                else:
-                    cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
-                    hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
-                    train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
-                    attn_processors[key] = CustomDiffusionAttnProcessor(
-                        train_kv=True,
-                        train_q_out=train_q_out,
-                        hidden_size=hidden_size,
-                        cross_attention_dim=cross_attention_dim,
-                    )
-                    attn_processors[key].load_state_dict(value_dict)
-        elif USE_PEFT_BACKEND:
-            # In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
-            # on the Unet
-            pass
+        if is_custom_diffusion:
+            attn_processors = self._process_custom_diffusion(state_dict=state_dict)
+        elif is_lora:
+            is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+                state_dict=state_dict,
+                unet_identifier_key=self.unet_name,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+            )
         else:
             raise ValueError(
-                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
+                f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training."
             )
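The dispatch above relies on the key-name heuristics computed earlier in the method; a toy illustration of how the two checkpoint formats are told apart (the keys are invented):

# LoRA checkpoints: every key mentions "lora" or ends with ".alpha".
lora_keys = ["mid_block.attn.to_q.lora_A.weight", "mid_block.attn.to_q.alpha"]
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in lora_keys)  # True

# Custom Diffusion checkpoints: at least one key mentions "custom_diffusion".
cd_keys = ["mid_block.attn.to_k_custom_diffusion.weight"]
is_custom_diffusion = any("custom_diffusion" in k for k in cd_keys)  # True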
         # <Unsafe code
         # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
-        # Now we remove any existing hooks to
+        # Now we remove any existing hooks to `_pipeline`.
 
+        # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
+        if is_custom_diffusion and _pipeline is not None:
+            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
+
+            # only custom diffusion needs to set attn processors
+            self.set_attn_processor(attn_processors)
+            self.to(dtype=self.dtype, device=self.device)
+
+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />
+    def _process_custom_diffusion(self, state_dict):
+        from ..models.attention_processor import CustomDiffusionAttnProcessor
+
+        attn_processors = {}
+        custom_diffusion_grouped_dict = defaultdict(dict)
+        for key, value in state_dict.items():
+            if len(value) == 0:
+                custom_diffusion_grouped_dict[key] = {}
+            else:
+                if "to_out" in key:
+                    attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                else:
+                    attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+        for key, value_dict in custom_diffusion_grouped_dict.items():
+            if len(value_dict) == 0:
+                attn_processors[key] = CustomDiffusionAttnProcessor(
+                    train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                )
+            else:
+                cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+                attn_processors[key] = CustomDiffusionAttnProcessor(
+                    train_kv=True,
+                    train_q_out=train_q_out,
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                )
+                attn_processors[key].load_state_dict(value_dict)
+
+        return attn_processors
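The grouping above splits each flattened key into an attention-processor path and a weight name; a toy example of the non-`to_out` branch (the key is invented):

key = "down_blocks.0.attentions.1.processor.to_k_custom_diffusion.weight"
attn_processor_key = ".".join(key.split(".")[:-2])  # "down_blocks.0.attentions.1.processor"
sub_key = ".".join(key.split(".")[-2:])             # "to_k_custom_diffusion.weight"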
+    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+        # This method does the following things:
+        # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
+        #    format. For legacy format no filtering is applied.
+        # 2. Converts the `state_dict` to the `peft` compatible format.
+        # 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the
+        #    `LoraConfig` specs.
+        # 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it.
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+        keys = list(state_dict.keys())
+
+        unet_keys = [k for k in keys if k.startswith(unet_identifier_key)]
+        unet_state_dict = {
+            k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys
+        }
+
+        if network_alphas is not None:
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)]
+            network_alphas = {
+                k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+            }
+
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
+
+        if len(state_dict_to_be_used) > 0:
+            if adapter_name in getattr(self, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
+                )
+
+            state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)
+
+            if network_alphas is not None:
+                # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
+                # `convert_unet_state_dict_to_peft` method.
+                network_alphas = convert_unet_state_dict_to_peft(network_alphas)
+
+            rank = {}
+            for key, val in state_dict.items():
+                if "lora_B" in key:
+                    rank[key] = val.shape[1]
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(self)
+
+            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+            # otherwise loading LoRA weights will lead to an error
+            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+
+            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
+            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+
+            if incompatible_keys is not None:
+                # check only for unexpected keys
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    logger.warning(
+                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                        f" {unexpected_keys}. "
+                    )
+
+        return is_model_cpu_offload, is_sequential_cpu_offload
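For intuition on the rank inference in `_process_lora` above: in PEFT format, `lora_B` weights have shape (out_features, rank), so `shape[1]` recovers the rank. A self-contained sketch with made-up key names and shapes:

import torch

state_dict = {
    "down_blocks.0.attentions.0.to_q.lora_A.weight": torch.zeros(4, 320),  # (rank, in_features)
    "down_blocks.0.attentions.0.to_q.lora_B.weight": torch.zeros(320, 4),  # (out_features, rank)
}
rank = {k: v.shape[1] for k, v in state_dict.items() if "lora_B" in k}
assert rank["down_blocks.0.attentions.0.to_q.lora_B.weight"] == 4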
     @classmethod
     # Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
     def _optionally_disable_offloading(cls, _pipeline):
         """
         Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
 
         Args:
             _pipeline (`DiffusionPipeline`):
                 The pipeline to disable offloading for.
 
         Returns:
             tuple:
                 A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
         """
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
 
-        # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet`
-        if not USE_PEFT_BACKEND:
-            if _pipeline is not None:
-                for _, component in _pipeline.components.items():
-                    if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                        is_sequential_cpu_offload = (
-                            isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                            or hasattr(component._hf_hook, "hooks")
-                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                        )
-
-                        logger.info(
-                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                        )
-                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+        if _pipeline is not None and _pipeline.hf_device_map is None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                    if not is_model_cpu_offload:
+                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                    if not is_sequential_cpu_offload:
+                        is_sequential_cpu_offload = (
+                            isinstance(component._hf_hook, AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )
+
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
         return (is_model_cpu_offload, is_sequential_cpu_offload)
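For intuition on the hook checks above, a hedged sketch of how the two offloading flavors are distinguished (assumes accelerate is installed; `classify_hook` is a hypothetical helper, not diffusers API):

from accelerate.hooks import AlignDevicesHook, CpuOffload

def classify_hook(hook):
    # CpuOffload hooks come from enable_model_cpu_offload(); AlignDevicesHook,
    # possibly nested under a sequential hook's `hooks` list, comes from
    # enable_sequential_cpu_offload().
    if isinstance(hook, CpuOffload):
        return "model_cpu_offload"
    if isinstance(hook, AlignDevicesHook) or (
        hasattr(hook, "hooks") and isinstance(hook.hooks[0], AlignDevicesHook)
    ):
        return "sequential_cpu_offload"
    return "none"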
-        # only custom diffusion needs to set attn processors
-        if is_custom_diffusion:
-            self.set_attn_processor(attn_processors)
-
-        # set lora layers
-        for target_module, lora_layer in lora_layers_list:
-            target_module.set_lora_layer(lora_layer)
-
-        self.to(dtype=self.dtype, device=self.device)
-
-        # Offload back.
-        if is_model_cpu_offload:
-            _pipeline.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
-        # Unsafe code />
-
-    def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
-        is_new_lora_format = all(
-            key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
-        )
-        if is_new_lora_format:
-            # Strip the `"unet"` prefix.
-            is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
-            if is_text_encoder_present:
-                warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
-                logger.warning(warn_message)
-            unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
-            state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
-
-        # change processor format to 'pure' LoRACompatibleLinear format
-        if any("processor" in k.split(".") for k in state_dict.keys()):
-
-            def format_to_lora_compatible(key):
-                if "processor" not in key.split("."):
-                    return key
-                return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
-
-            state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
-
-            if network_alphas is not None:
-                network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
-        return state_dict, network_alphas
 
     def save_attn_procs(
         self,
@@ -460,6 +448,23 @@ class UNet2DConditionLoadersMixin:
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        is_custom_diffusion = any(
+            isinstance(
+                x,
+                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
+            )
+            for (_, x) in self.attn_processors.items()
+        )
+        if is_custom_diffusion:
+            state_dict = self._get_custom_diffusion_state_dict()
+        else:
+            if not USE_PEFT_BACKEND:
+                raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
+
+            from peft.utils import get_peft_model_state_dict
+
+            state_dict = get_peft_model_state_dict(self)
+
         if save_function is None:
             if safe_serialization:
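A minimal save/reload round trip through the new PEFT-backed `save_attn_procs` path, as a hedged sketch (assumes `unet` already carries a LoRA adapter and the PEFT backend is installed; the directory name is a placeholder):

unet.save_attn_procs("lora-out", weight_name="pytorch_lora_weights.safetensors")
unet.unload_lora()
unet.load_attn_procs("lora-out", weight_name="pytorch_lora_weights.safetensors")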
@@ -471,36 +476,6 @@ class UNet2DConditionLoadersMixin:
 
         os.makedirs(save_directory, exist_ok=True)
 
-        is_custom_diffusion = any(
-            isinstance(
-                x,
-                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
-            )
-            for (_, x) in self.attn_processors.items()
-        )
-        if is_custom_diffusion:
-            model_to_save = AttnProcsLayers(
-                {
-                    y: x
-                    for (y, x) in self.attn_processors.items()
-                    if isinstance(
-                        x,
-                        (
-                            CustomDiffusionAttnProcessor,
-                            CustomDiffusionAttnProcessor2_0,
-                            CustomDiffusionXFormersAttnProcessor,
-                        ),
-                    )
-                }
-            )
-            state_dict = model_to_save.state_dict()
-            for name, attn in self.attn_processors.items():
-                if len(attn.state_dict()) == 0:
-                    state_dict[name] = {}
-        else:
-            model_to_save = AttnProcsLayers(self.attn_processors)
-            state_dict = model_to_save.state_dict()
-
         if weight_name is None:
             if safe_serialization:
                 weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
@@ -512,56 +487,84 @@ class UNet2DConditionLoadersMixin:
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")
 
+    def _get_custom_diffusion_state_dict(self):
+        from ..models.attention_processor import (
+            CustomDiffusionAttnProcessor,
+            CustomDiffusionAttnProcessor2_0,
+            CustomDiffusionXFormersAttnProcessor,
+        )
+
+        model_to_save = AttnProcsLayers(
+            {
+                y: x
+                for (y, x) in self.attn_processors.items()
+                if isinstance(
+                    x,
+                    (
+                        CustomDiffusionAttnProcessor,
+                        CustomDiffusionAttnProcessor2_0,
+                        CustomDiffusionXFormersAttnProcessor,
+                    ),
+                )
+            }
+        )
+        state_dict = model_to_save.state_dict()
+        for name, attn in self.attn_processors.items():
+            if len(attn.state_dict()) == 0:
+                state_dict[name] = {}
+
+        return state_dict
     def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `fuse_lora()`.")
+
         self.lora_scale = lora_scale
         self._safe_fusing = safe_fusing
         self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
 
     def _fuse_lora_apply(self, module, adapter_names=None):
-        if not USE_PEFT_BACKEND:
-            if hasattr(module, "_fuse_lora"):
-                module._fuse_lora(self.lora_scale, self._safe_fusing)
-
-            if adapter_names is not None:
-                raise ValueError(
-                    "The `adapter_names` argument is not supported in your environment. Please switch"
-                    " to PEFT backend to use this argument by installing latest PEFT and transformers."
-                    " `pip install -U peft transformers`"
-                )
-        else:
-            from peft.tuners.tuners_utils import BaseTunerLayer
-
-            merge_kwargs = {"safe_merge": self._safe_fusing}
-
-            if isinstance(module, BaseTunerLayer):
-                if self.lora_scale != 1.0:
-                    module.scale_layer(self.lora_scale)
-
-                # For BC with previous PEFT versions, we need to check the signature
-                # of the `merge` method to see if it supports the `adapter_names` argument.
-                supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
-                if "adapter_names" in supported_merge_kwargs:
-                    merge_kwargs["adapter_names"] = adapter_names
-                elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
-                    raise ValueError(
-                        "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
-                        " to the latest version of PEFT. `pip install -U peft`"
-                    )
-
-                module.merge(**merge_kwargs)
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        merge_kwargs = {"safe_merge": self._safe_fusing}
+
+        if isinstance(module, BaseTunerLayer):
+            if self.lora_scale != 1.0:
+                module.scale_layer(self.lora_scale)
+
+            # For BC with previous PEFT versions, we need to check the signature
+            # of the `merge` method to see if it supports the `adapter_names` argument.
+            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+            if "adapter_names" in supported_merge_kwargs:
+                merge_kwargs["adapter_names"] = adapter_names
+            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                raise ValueError(
+                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+                    " to the latest version of PEFT. `pip install -U peft`"
+                )
+
+            module.merge(**merge_kwargs)
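Typical use of the fuse/unfuse pair around this point, sketched under the assumption that `unet` has a loaded adapter and the PEFT backend is installed:

unet.fuse_lora(lora_scale=1.0, safe_fusing=False)  # merge LoRA deltas into the base weights
# ... run inference without per-layer LoRA overhead ...
unet.unfuse_lora()                                 # restore the original weights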
     def unfuse_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
         self.apply(self._unfuse_lora_apply)
 
     def _unfuse_lora_apply(self, module):
-        if not USE_PEFT_BACKEND:
-            if hasattr(module, "_unfuse_lora"):
-                module._unfuse_lora()
-        else:
-            from peft.tuners.tuners_utils import BaseTunerLayer
-
-            if isinstance(module, BaseTunerLayer):
-                module.unmerge()
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        if isinstance(module, BaseTunerLayer):
+            module.unmerge()
 
+    def unload_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `unload_lora()`.")
+
+        from ..utils import recurse_remove_peft_layers
+
+        recurse_remove_peft_layers(self)
+        if hasattr(self, "peft_config"):
+            del self.peft_config
 
     def set_adapters(
         self,

src/diffusers/models/unets/unet_2d_condition.py
@@ -903,17 +903,6 @@ class UNet2DConditionModel(
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
-    def unload_lora(self):
-        """Unloads LoRA weights."""
-        deprecate(
-            "unload_lora",
-            "0.28.0",
-            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()`.",
-        )
-        for module in self.modules():
-            if hasattr(module, "set_lora_layer"):
-                module.set_lora_layer(None)
-
     def get_time_embed(
         self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
     ) -> Optional[torch.Tensor]:

src/diffusers/models/unets/unet_3d_condition.py
@@ -22,7 +22,7 @@ import torch.utils.checkpoint
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import UNet2DConditionLoadersMixin
-from ...utils import BaseOutput, deprecate, logging
+from ...utils import BaseOutput, logging
 from ..activations import get_activation
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -546,18 +546,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
-    def unload_lora(self):
-        """Unloads LoRA weights."""
-        deprecate(
-            "unload_lora",
-            "0.28.0",
-            "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()`.",
-        )
-        for module in self.modules():
-            if hasattr(module, "set_lora_layer"):
-                module.set_lora_layer(None)
-
     def forward(
         self,
         sample: torch.Tensor,

tests/models/unets/test_models_unet_2d_condition.py
@@ -37,7 +37,9 @@ from diffusers.utils.testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
+    is_peft_available,
     load_hf_numpy,
+    require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
     require_torch_accelerator_with_training,
@@ -51,11 +53,38 @@ from diffusers.utils.testing_utils import (
 from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
 
 
+if is_peft_available():
+    from peft import LoraConfig
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+
 logger = logging.get_logger(__name__)
 
 enable_full_determinism()
 
 
+def get_unet_lora_config():
+    rank = 4
+    unet_lora_config = LoraConfig(
+        r=rank,
+        lora_alpha=rank,
+        target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+        init_lora_weights=False,
+        use_dora=False,
+    )
+    return unet_lora_config
+
+
+def check_if_lora_correctly_set(model) -> bool:
+    """
+    Checks if the LoRA layers are correctly set with peft
+    """
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            return True
+    return False
+
+
 def create_ip_adapter_state_dict(model):
     # "ip_adapter" (cross-attention weights)
     ip_cross_attn_state_dict = {}
@@ -1005,6 +1034,65 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
         assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
 
+    @require_peft_backend
+    def test_lora(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        # forward pass without LoRA
+        with torch.no_grad():
+            non_lora_sample = model(**inputs_dict).sample
+
+        unet_lora_config = get_unet_lora_config()
+        model.add_adapter(unet_lora_config)
+
+        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
+
+        # forward pass with LoRA
+        with torch.no_grad():
+            lora_sample = model(**inputs_dict).sample
+
+        assert not torch.allclose(
+            non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4
+        ), "LoRA injected UNet should produce different results."
+
+    @require_peft_backend
+    def test_lora_serialization(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        # forward pass without LoRA
+        with torch.no_grad():
+            non_lora_sample = model(**inputs_dict).sample
+
+        unet_lora_config = get_unet_lora_config()
+        model.add_adapter(unet_lora_config)
+
+        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
+
+        # forward pass with LoRA
+        with torch.no_grad():
+            lora_sample_1 = model(**inputs_dict).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            model.unload_lora()
+            model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+        assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
+
+        with torch.no_grad():
+            lora_sample_2 = model(**inputs_dict).sample
+
+        assert not torch.allclose(
+            non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
+        ), "LoRA injected UNet should produce different results."
+        assert torch.allclose(
+            lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
+        ), "Loading from a saved checkpoint should produce identical results."
+
 
 @slow
 class UNet2DConditionModelIntegrationTests(unittest.TestCase):