mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Revert "[LoRA] introduce LoraBaseMixin to promote reusability." (#8773)
Revert "[LoRA] introduce `LoraBaseMixin` to promote reusability. (#8670)"
This reverts commit a2071a1837.
@@ -12,13 +12,10 @@ specific language governing permissions and limitations under the License.

# LoRA

LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the denoiser, text encoder or both. The denoiser usually corresponds to a UNet ([`UNet2DConditionModel`], for example) or a Transformer ([`SD3Transformer2DModel`], for example). There are several classes for loading LoRA weights:

LoRA is a fast and lightweight training method that inserts and trains a significantly smaller number of parameters instead of all the model parameters. This produces a smaller file (~100 MBs) and makes it easier to quickly train a model to learn a new concept. LoRA weights are typically loaded into the UNet, text encoder or both. There are two classes for loading LoRA weights:

- [`LoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`LoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, and unload LoRAs, and more.
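
In practice these mixins are consumed through the pipelines that inherit them, so LoRA weights are managed by calling the methods directly on a pipeline object. A minimal sketch of the common workflow (the checkpoint, LoRA repository, and adapter name are illustrative, not part of this diff):

```py
import torch
from diffusers import DiffusionPipeline

# StableDiffusionXLLoraLoaderMixin methods are available on the SDXL pipeline itself.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load LoRA weights into the UNet (and the text encoders, if the file contains them).
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Optionally fuse the LoRA into the base weights for faster inference, and undo it later.
pipeline.fuse_lora(lora_scale=0.7)
image = pipeline("pixel art of an astronaut riding a horse").images[0]
pipeline.unfuse_lora()

# Remove the LoRA parameters entirely when they are no longer needed.
pipeline.unload_lora_weights()
```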

<Tip>

@@ -32,16 +29,4 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

## StableDiffusionXLLoraLoaderMixin

[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin

## SD3LoraLoaderMixin

[[autodoc]] loaders.lora.SD3LoraLoaderMixin

## AmusedLoraLoaderMixin

[[autodoc]] loaders.lora.AmusedLoraLoaderMixin

## LoraBaseMixin

[[autodoc]] loaders.lora_base.LoraBaseMixin

[[autodoc]] loaders.lora.StableDiffusionXLLoraLoaderMixin
@@ -41,7 +41,7 @@ from transformers import (

import diffusers.optimization
from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel
from diffusers.loaders import AmusedLoraLoaderMixin
from diffusers.loaders import LoraLoaderMixin
from diffusers.utils import is_wandb_available

@@ -532,7 +532,7 @@ def main(args):

weights.pop()

if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None:
AmusedLoraLoaderMixin.save_lora_weights(
LoraLoaderMixin.save_lora_weights(
output_dir,
transformer_lora_layers=transformer_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,

@@ -566,11 +566,11 @@ def main(args):

raise ValueError(f"unexpected save model: {model.__class__}")

if transformer is not None or text_encoder_ is not None:
lora_state_dict, network_alphas = AmusedLoraLoaderMixin.lora_state_dict(input_dir)
AmusedLoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
)
AmusedLoraLoaderMixin.load_lora_into_transformer(
LoraLoaderMixin.load_lora_into_transformer(
lora_state_dict, network_alphas=network_alphas, transformer=transformer
)
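
For context, the `save_lora_weights` / `lora_state_dict` calls above live inside the Accelerate state hooks that the training script registers, so that `accelerator.save_state()` and `accelerator.load_state()` round-trip the LoRA layers. A simplified sketch of that wiring (hook bodies elided; the hook names follow the script's conventions but are written from memory, not copied from this diff):

```py
from accelerate import Accelerator

accelerator = Accelerator()

def save_model_hook(models, weights, output_dir):
    # Extract the transformer/text-encoder LoRA layers from `models`,
    # then serialize them with LoraLoaderMixin.save_lora_weights(...).
    ...

def load_model_hook(models, input_dir):
    # Read them back with LoraLoaderMixin.lora_state_dict(input_dir) and
    # re-inject them via the load_lora_into_* helpers shown above.
    ...

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
```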
@@ -55,18 +55,11 @@ _import_structure = {}

if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["transformer_sd3"] = ["SD3TransformerLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
if is_transformers_available():
_import_structure["single_file"] = ["FromSingleFileMixin"]
_import_structure["lora"] = [
"AmusedLoraLoaderMixin",
"LoraLoaderMixin",
"SD3LoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
]
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "SD3LoraLoaderMixin"]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]

@@ -76,18 +69,12 @@ _import_structure["peft"] = ["PeftAdapterMixin"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_sd3 import SD3TransformerLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers

if is_transformers_available():
from .ip_adapter import IPAdapterMixin
from .lora import (
AmusedLoraLoaderMixin,
LoraLoaderMixin,
SD3LoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
)
from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin
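
The names are declared twice — once as strings in `_import_structure` and once as real imports under `TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT` — because diffusers builds this module lazily. A sketch of the pattern that typically closes such an `__init__.py` (not part of this hunk, written from memory):

```py
else:
    import sys

    from ..utils import _LazyModule

    # Names listed in _import_structure are resolved on first attribute access,
    # which keeps `import diffusers` cheap when optional backends are missing.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
```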
File diff suppressed because it is too large
@@ -1,938 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
|
||||
from ..models.modeling_utils import load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
"""Utility class for handling LoRAs."""
|
||||
|
||||
is_unet_denoiser = False
|
||||
is_transformer_denoiser = False
|
||||
num_fused_loras = 0
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if hasattr(self, "text_encoder"):
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
# TODO: @younesbelkada handle this in transformers side
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
@classmethod
|
||||
def _fetch_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weight_name,
|
||||
use_safetensors,
|
||||
local_files_only,
|
||||
cache_dir,
|
||||
force_download,
|
||||
resume_download,
|
||||
proxies,
|
||||
token,
|
||||
revision,
|
||||
subfolder,
|
||||
user_agent,
|
||||
allow_pickle,
|
||||
):
|
||||
from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
||||
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
# Here we're relaxing the loading check to enable more Inference API
|
||||
# friendliness where sometimes, it's not at all possible to automatically
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
file_extension=".safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
# try loading non-safetensors weights
|
||||
model_file = None
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = cls._best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
return state_dict
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(
|
||||
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||
):
|
||||
from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
||||
|
||||
if local_files_only or HF_HUB_OFFLINE:
|
||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
return
|
||||
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
||||
targeted_files = [
|
||||
f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
|
||||
]
|
||||
else:
|
||||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
||||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
||||
if len(targeted_files) == 0:
|
||||
return
|
||||
|
||||
# "scheduler" does not correspond to a LoRA checkpoint.
|
||||
# "optimizer" does not correspond to a LoRA checkpoint
|
||||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
||||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
||||
targeted_files = list(
|
||||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
||||
)
|
||||
|
||||
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
||||
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
||||
|
||||
if len(targeted_files) > 1:
|
||||
raise ValueError(
|
||||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
||||
)
|
||||
weight_name = targeted_files[0]
|
||||
return weight_name
|
||||
|
||||
def load_lora_weights(self, **kwargs):
|
||||
raise NotImplementedError("`load_lora_weights()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(cls, **kwargs):
|
||||
raise NotImplementedError("`save_lora_weights()` not implemented.")
|
||||
|
||||
@classmethod
|
||||
def lora_state_dict(cls, **kwargs):
|
||||
raise NotImplementedError("`lora_state_dict()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_text_encoder(
|
||||
cls,
|
||||
state_dict,
|
||||
network_alphas,
|
||||
text_encoder,
|
||||
prefix=None,
|
||||
lora_scale=1.0,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
See `LoRALinearLayer` for more details.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The text encoder model to load the LoRA layers into.
|
||||
prefix (`str`):
|
||||
Expected prefix of the `text_encoder` in the `state_dict`.
|
||||
lora_scale (`float`):
|
||||
How much to scale the output of the lora linear layer before it is added with the output of the regular
|
||||
lora layer.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
prefix = cls.text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if any(cls.text_encoder_name in key for key in keys):
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
text_encoder_lora_state_dict = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
|
||||
}
|
||||
|
||||
if len(text_encoder_lora_state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in text_encoder_lora_state_dict:
|
||||
continue
|
||||
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [
|
||||
k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix
|
||||
]
|
||||
network_alphas = {
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
def unload_lora_weights(self):
|
||||
"""
|
||||
Unloads the LoRA parameters.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
|
||||
>>> pipeline.unload_lora_weights()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if self.is_unet_denoiser:
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.unload_lora()
|
||||
elif self.is_transformer_denoiser:
|
||||
transformer = (
|
||||
getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
)
|
||||
transformer.unload_lora()
|
||||
else:
|
||||
raise ValueError("No valid denoiser found in the network.")
|
||||
|
||||
# Safe to call the following regardless of LoRA.
|
||||
self._remove_text_encoder_monkey_patch()
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
fuse_denoiser: bool = True,
|
||||
fuse_text_encoder: bool = True,
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
fuse_denoiser (`bool`, defaults to `True`):
|
||||
Whether to fuse the denoiser (UNet, Transformer, etc.) LoRA parameters.
|
||||
fuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
fuse_unet = True if fuse_denoiser and self.is_unet_denoiser else False
|
||||
fuse_transformer = True if fuse_denoiser and self.is_transformer_denoiser else False
|
||||
|
||||
if fuse_unet:
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
elif fuse_transformer:
|
||||
transformer = (
|
||||
getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
)
|
||||
transformer.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing, adapter_names=adapter_names)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing, adapter_names=adapter_names)
|
||||
|
||||
if fuse_denoiser or fuse_text_encoder:
|
||||
self.num_fused_loras += 1
|
||||
|
||||
def unfuse_lora(self, unfuse_denoiser: bool = True, unfuse_text_encoder: bool = True):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
unfuse_unet = True if unfuse_denoiser and self.is_unet_denoiser else False
|
||||
unfuse_transformer = True if unfuse_denoiser and self.is_transformer_denoiser else False
|
||||
|
||||
if unfuse_unet:
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
for module in unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
elif unfuse_transformer:
|
||||
transformer = (
|
||||
getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
)
|
||||
for module in transformer.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
if unfuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
unfuse_text_encoder_lora(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
unfuse_text_encoder_lora(self.text_encoder_2)
|
||||
|
||||
self.num_fused_loras -= 1
|
||||
|
||||
def set_adapters_for_text_encoder(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the adapter layers for the text encoder.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
text_encoder_weights (`List[float]`, *optional*):
|
||||
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
def process_weights(adapter_names, weights):
|
||||
# Expand weights into a list, one entry per adapter
|
||||
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
|
||||
if not isinstance(weights, list):
|
||||
weights = [weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||
)
|
||||
|
||||
# Set None values to default of 1.0
|
||||
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
|
||||
weights = [w if w is not None else 1.0 for w in weights]
|
||||
|
||||
return weights
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
|
||||
text_encoder = text_encoder or getattr(self, "text_encoder", None)
|
||||
if text_encoder is None:
|
||||
raise ValueError(
|
||||
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
|
||||
)
|
||||
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
|
||||
|
||||
def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
|
||||
"""
|
||||
Disables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
|
||||
`text_encoder` attribute.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
text_encoder = text_encoder or getattr(self, "text_encoder", None)
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(text_encoder, enabled=False)
|
||||
|
||||
def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
|
||||
"""
|
||||
Enables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
text_encoder = text_encoder or getattr(self, "text_encoder", None)
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(self.text_encoder, enabled=True)
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
||||
):
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
|
||||
adapter_weights = copy.deepcopy(adapter_weights)
|
||||
|
||||
# Expand weights into a list, one entry per adapter
|
||||
if not isinstance(adapter_weights, list):
|
||||
adapter_weights = [adapter_weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(adapter_weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
|
||||
)
|
||||
|
||||
# Decompose weights into weights for unet, text_encoder and text_encoder_2
|
||||
denoiser_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []
|
||||
|
||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||
all_adapters = {
|
||||
adapter for adapters in list_adapters.values() for adapter in adapters
|
||||
} # eg ["adapter1", "adapter2"]
|
||||
invert_list_adapters = {
|
||||
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
||||
for adapter in all_adapters
|
||||
} # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
|
||||
denoiser_name = "unet" if self.is_unet_denoiser else "transformer"
|
||||
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
||||
if isinstance(weights, dict):
|
||||
denoiser_lora_weight = weights.pop(denoiser_name, None)
|
||||
text_encoder_lora_weight = weights.pop("text_encoder", None)
|
||||
text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)
|
||||
|
||||
if len(weights) > 0:
|
||||
raise ValueError(
|
||||
f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
|
||||
)
|
||||
|
||||
if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
|
||||
logger.warning(
|
||||
"Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
|
||||
)
|
||||
|
||||
# warn if adapter doesn't have parts specified by adapter_weights
|
||||
for part_weight, part_name in zip(
|
||||
[denoiser_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
|
||||
[denoiser_name, "text_encoder", "text_encoder_2"],
|
||||
):
|
||||
if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
|
||||
logger.warning(
|
||||
f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
|
||||
)
|
||||
|
||||
else:
|
||||
denoiser_lora_weight = weights
|
||||
text_encoder_lora_weight = weights
|
||||
text_encoder_2_lora_weight = weights
|
||||
|
||||
denoiser_lora_weights.append(denoiser_lora_weight)
|
||||
text_encoder_lora_weights.append(text_encoder_lora_weight)
|
||||
text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)
|
||||
|
||||
if denoiser_name == "unet":
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
# Handle the UNET
|
||||
unet.set_adapters(adapter_names, denoiser_lora_weights)
|
||||
else:
|
||||
transformer = (
|
||||
getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
)
|
||||
# Handle the UNET
|
||||
transformer.set_adapters(adapter_names, denoiser_lora_weights)
|
||||
|
||||
# Handle the Text Encoder
|
||||
if hasattr(self, "text_encoder"):
|
||||
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)
|
||||
|
||||
def disable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# Disable denoiser adapters
|
||||
if self.is_unet_denoiser:
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.disable_lora()
|
||||
else:
|
||||
transformer = (
|
||||
getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
)
|
||||
transformer.disable_lora()
|
||||
|
||||
# Disable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
self.disable_lora_for_text_encoder(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
self.disable_lora_for_text_encoder(self.text_encoder_2)
|
||||
|
||||
def enable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# Enable unet adapters
|
||||
if self.is_unet_denoiser:
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.enable_lora()
|
||||
else:
|
||||
transformer = (
|
||||
getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
)
|
||||
transformer.enable_lora()
|
||||
|
||||
# Enable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
self.enable_lora_for_text_encoder(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
self.enable_lora_for_text_encoder(self.text_encoder_2)
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
"""
|
||||
Args:
|
||||
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
The names of the adapter to delete. Can be a single string or a list of strings
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if isinstance(adapter_names, str):
|
||||
adapter_names = [adapter_names]
|
||||
|
||||
# Delete unet adapters
|
||||
if self.is_unet_denoiser:
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet.delete_adapters(adapter_names)
|
||||
else:
|
||||
transformer = (
|
||||
getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
)
|
||||
transformer.delete_adapters(adapter_names)
|
||||
|
||||
for adapter_name in adapter_names:
|
||||
# Delete text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
delete_adapter_layers(self.text_encoder, adapter_name)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
delete_adapter_layers(self.text_encoder_2, adapter_name)
|
||||
|
||||
def get_active_adapters(self) -> List[str]:
|
||||
"""
|
||||
Gets the list of the current active adapters.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
pipeline.get_active_adapters()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
active_adapters = []
|
||||
if self.is_unet_denoiser:
|
||||
denoiser = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
else:
|
||||
denoiser = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
|
||||
for module in denoiser.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
break
|
||||
|
||||
return active_adapters
|
||||
|
||||
def get_list_adapters(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Gets the current list of all available adapters in the pipeline.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
set_adapters = {}
|
||||
|
||||
if hasattr(self, "text_encoder") and hasattr(self.text_encoder, "peft_config"):
|
||||
set_adapters["text_encoder"] = list(self.text_encoder.peft_config.keys())
|
||||
|
||||
if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
|
||||
set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
|
||||
|
||||
if self.is_unet_denoiser:
|
||||
denoiser = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
denoiser_name = self.unet_name
|
||||
else:
|
||||
denoiser = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
denoiser_name = self.transformer_name
|
||||
|
||||
if hasattr(self, denoiser_name) and hasattr(denoiser, "peft_config"):
|
||||
set_adapters[denoiser_name] = (
|
||||
list(self.unet.peft_config.keys())
|
||||
if self.is_unet_denoiser
|
||||
else list(self.transformer.peft_config.keys())
|
||||
)
|
||||
|
||||
return set_adapters
|
||||
|
||||
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
||||
"""
|
||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||
you want to load multiple adapters and free some GPU memory.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]`):
|
||||
List of adapters to send device to.
|
||||
device (`Union[torch.device, str, int]`):
|
||||
Device to send the adapters to. Can be either a torch device, a str or an integer.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
# Handle the denoiser
|
||||
if self.is_unet_denoiser:
|
||||
denoiser = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
else:
|
||||
denoiser = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
|
||||
for denoiser_module in denoiser.modules():
|
||||
if isinstance(denoiser_module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
denoiser_module.lora_A[adapter_name].to(device)
|
||||
denoiser_module.lora_B[adapter_name].to(device)
|
||||
# this is a param, not a module, so device placement is not in-place -> re-assign
|
||||
if (
|
||||
hasattr(denoiser_module, "lora_magnitude_vector")
|
||||
and denoiser_module.lora_magnitude_vector is not None
|
||||
):
|
||||
denoiser_module.lora_magnitude_vector[adapter_name] = denoiser_module.lora_magnitude_vector[
|
||||
adapter_name
|
||||
].to(device)
|
||||
|
||||
# Handle the text encoder
|
||||
modules_to_process = []
|
||||
if hasattr(self, "text_encoder"):
|
||||
modules_to_process.append(self.text_encoder)
|
||||
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
modules_to_process.append(self.text_encoder_2)
|
||||
|
||||
for text_encoder in modules_to_process:
|
||||
# loop over submodules
|
||||
for text_encoder_module in text_encoder.modules():
|
||||
if isinstance(text_encoder_module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
text_encoder_module.lora_A[adapter_name].to(device)
|
||||
text_encoder_module.lora_B[adapter_name].to(device)
|
||||
# this is a param, not a module, so device placement is not in-place -> re-assign
|
||||
if (
|
||||
hasattr(text_encoder_module, "lora_magnitude_vector")
|
||||
and text_encoder_module.lora_magnitude_vector is not None
|
||||
):
|
||||
text_encoder_module.lora_magnitude_vector[
|
||||
adapter_name
|
||||
] = text_encoder_module.lora_magnitude_vector[adapter_name].to(device)
|
||||
|
||||
@staticmethod
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
save_directory: str,
|
||||
is_main_process: bool,
|
||||
weight_name: str,
|
||||
save_function: Callable,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
||||
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
if save_function is None:
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
save_path = Path(save_directory, weight_name).as_posix()
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
@property
|
||||
def lora_scale(self) -> float:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
# if _lora_scale has not been set, return 1
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
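
The class above is what gave pipelines a shared set of adapter-management utilities. A sketch of how those utilities surface on a pipeline that inherits a LoRA loader mixin (repositories and adapter names are taken from the docstrings above; the printed outputs are illustrative):

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load two adapters, then weight and activate them together.
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.set_adapters(["toy", "pixel"], adapter_weights=[1.0, 0.5])

print(pipeline.get_active_adapters())  # e.g. ["toy", "pixel"]
print(pipeline.get_list_adapters())    # e.g. {"unet": ["toy", "pixel"], "text_encoder": ["toy", "pixel"]}

# Move an adapter's weights to the CPU to free GPU memory, or drop it entirely.
pipeline.set_lora_device(["pixel"], device="cpu")
pipeline.delete_adapters("pixel")

# Toggle LoRA without unloading, or remove every adapter at the end.
pipeline.disable_lora()
pipeline.enable_lora()
pipeline.unload_lora_weights()
```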
@@ -1,261 +0,0 @@
|
||||
import inspect
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
delete_adapter_layers,
|
||||
is_accelerate_available,
|
||||
logging,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .lora import TEXT_ENCODER_NAME, TRANSFORMER_NAME
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class SD3TransformerLoadersMixin:
|
||||
"""
|
||||
Load LoRA layers into a [`SD3Transformer2DModel`].
|
||||
"""
|
||||
|
||||
text_encoder_name = TEXT_ENCODER_NAME
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
# Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.fuse_lora
|
||||
def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for `fuse_lora()`.")
|
||||
|
||||
self.lora_scale = lora_scale
|
||||
self._safe_fusing = safe_fusing
|
||||
self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
|
||||
|
||||
# Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin._fuse_lora_apply
|
||||
def _fuse_lora_apply(self, module, adapter_names=None):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
merge_kwargs = {"safe_merge": self._safe_fusing}
|
||||
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if self.lora_scale != 1.0:
|
||||
module.scale_layer(self.lora_scale)
|
||||
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
|
||||
" to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
# Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.unfuse_lora
|
||||
def unfuse_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for `unfuse_lora()`.")
|
||||
self.apply(self._unfuse_lora_apply)
|
||||
|
||||
# Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin._unfuse_lora_apply
|
||||
def _unfuse_lora_apply(self, module):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
# Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.unload_lora
|
||||
def unload_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for `unload_lora()`.")
|
||||
|
||||
from ..utils import recurse_remove_peft_layers
|
||||
|
||||
recurse_remove_peft_layers(self)
|
||||
if hasattr(self, "peft_config"):
|
||||
del self.peft_config
|
||||
|
||||
# This class is almost the same but it doesn't do `_maybe_expand_lora_scales()` yet. We will work on adding
|
||||
# this support in a future PR.
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
|
||||
):
|
||||
"""
|
||||
Set the currently active adapters for use in the Transformer.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
adapter_weights (`Union[List[float], float]`, *optional*):
|
||||
The adapter(s) weights to use with the Transformer. If `None`, the weights are set to `1.0` for all the
|
||||
adapters.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
||||
)
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for `set_adapters()`.")
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
|
||||
# Expand weights into a list, one entry per adapter
|
||||
# examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
|
||||
if not isinstance(weights, list):
|
||||
weights = [weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
|
||||
)
|
||||
|
||||
# Set None values to default of 1.0
|
||||
# e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
|
||||
weights = [w if w is not None else 1.0 for w in weights]
|
||||
|
||||
set_weights_and_activate_adapters(self, adapter_names, weights)
|
||||
|
||||
# Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.disable_lora with UNet->Transformer
|
||||
def disable_lora(self):
|
||||
"""
|
||||
Disable the Transformer's active LoRA layers.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
||||
)
|
||||
pipeline.disable_lora()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
set_adapter_layers(self, enabled=False)
|
||||
|
||||
# Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.enable_lora with UNet->Transformer
|
||||
def enable_lora(self):
|
||||
"""
|
||||
Enable the Transformer's active LoRA layers.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
||||
)
|
||||
pipeline.enable_lora()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
set_adapter_layers(self, enabled=True)
|
||||
|
||||
# Copied from diffusers.loaders.unet.UNet2DConditionLoadersMixin.delete_adapters with UNet->Transformer
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
"""
|
||||
Delete an adapter's LoRA layers from the Transformer.
|
||||
|
||||
Args:
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
The names (single string or list of strings) of the adapter to delete.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
|
||||
)
|
||||
pipeline.delete_adapters("cinematic")
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if isinstance(adapter_names, str):
|
||||
adapter_names = [adapter_names]
|
||||
|
||||
for adapter_name in adapter_names:
|
||||
delete_adapter_layers(self, adapter_name)
|
||||
|
||||
# Pop also the corresponding adapter from the config
|
||||
if hasattr(self, "peft_config"):
|
||||
self.peft_config.pop(adapter_name, None)
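
The `SD3TransformerLoadersMixin` removed above attached the same adapter controls directly to `SD3Transformer2DModel`, so that before this revert they could be driven at the model level rather than only through the pipeline. A sketch of that pre-revert usage, assuming LoRA adapters named "cinematic" and "pixel" were already injected into the transformer (for example through a pipeline's `load_lora_weights`); none of this applies once the revert lands:

```py
import torch
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer", torch_dtype=torch.float16
)

# Adapter controls provided by the (now reverted) SD3TransformerLoadersMixin:
transformer.set_adapters(["cinematic", "pixel"], weights=[0.5, 0.5])
transformer.fuse_lora(lora_scale=0.7)  # merge the active adapters into the base weights
transformer.unfuse_lora()              # undo the merge
transformer.delete_adapters("pixel")   # drop one adapter's layers and config
transformer.unload_lora()              # strip all PEFT layers from the model
```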
@@ -362,7 +362,7 @@ class UNet2DConditionLoadersMixin:
|
||||
return is_model_cpu_offload, is_sequential_cpu_offload
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
# Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
@@ -13,13 +13,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import inspect
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3TransformerLoadersMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention import JointTransformerBlock
|
||||
from ...models.attention_processor import Attention, AttentionProcessor
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
@@ -32,9 +34,7 @@ from ..modeling_outputs import Transformer2DModelOutput
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SD3Transformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3TransformerLoadersMixin
|
||||
):
|
||||
class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
The Transformer model introduced in Stable Diffusion 3.
|
||||
|
||||
@@ -241,6 +241,47 @@ class SD3Transformer2DModel(
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `fuse_lora()`.")

        self.lora_scale = lora_scale
        self._safe_fusing = safe_fusing
        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))

    def _fuse_lora_apply(self, module, adapter_names=None):
        from peft.tuners.tuners_utils import BaseTunerLayer

        merge_kwargs = {"safe_merge": self._safe_fusing}

        if isinstance(module, BaseTunerLayer):
            if self.lora_scale != 1.0:
                module.scale_layer(self.lora_scale)

            # For BC with previous PEFT versions, we need to check the signature
            # of the `merge` method to see if it supports the `adapter_names` argument.
            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
            if "adapter_names" in supported_merge_kwargs:
                merge_kwargs["adapter_names"] = adapter_names
            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                raise ValueError(
                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
                    " to the latest version of PEFT. `pip install -U peft`"
                )

            module.merge(**merge_kwargs)

    def unfuse_lora(self):
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
        self.apply(self._unfuse_lora_apply)

    def _unfuse_lora_apply(self, module):
        from peft.tuners.tuners_utils import BaseTunerLayer

        if isinstance(module, BaseTunerLayer):
            module.unmerge()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
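These model-level helpers back the pipeline-level `fuse_lora()`/`unfuse_lora()` calls exercised in the test file later in this diff. A minimal usage sketch, assuming an SD3 pipeline with LoRA weights available; the checkpoint id and LoRA path below are illustrative, not part of this diff:

```py
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/sd3-lora")  # placeholder LoRA repo id or local folder

pipe.fuse_lora()  # merge the LoRA weights into the transformer for faster inference
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
pipe.unfuse_lora()  # restore the original, unmerged weights
```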
@@ -30,12 +30,9 @@ from ...models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -349,7 +346,6 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        clip_skip: Optional[int] = None,
        max_sequence_length: int = 256,
        lora_scale: Optional[float] = None,
    ):
        r"""

@@ -395,22 +391,9 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
@@ -513,16 +496,6 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
            )

        if self.text_encoder is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    def check_inputs(
@@ -29,12 +29,9 @@ from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -332,7 +329,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        clip_skip: Optional[int] = None,
        max_sequence_length: int = 256,
        lora_scale: Optional[float] = None,
    ):
        r"""

@@ -378,22 +374,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
@@ -496,16 +479,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
                [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
            )

        if self.text_encoder is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    def check_inputs(
@@ -814,9 +787,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle

        device = self._execution_device

        lora_scale = (
            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
@@ -838,7 +808,6 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
            clip_skip=self.clip_skip,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )

        if self.do_classifier_free_guidance:
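The hunks above touch the plumbing that reads a LoRA scale from `joint_attention_kwargs` and forwards it to `encode_prompt`. The test file later in this diff exercises this path by passing `joint_attention_kwargs={"scale": ...}` at call time; a minimal sketch of that call pattern, assuming `pipe` is a `StableDiffusion3Pipeline` with LoRA weights already loaded:

```py
# Hedged sketch: the prompt and step count are illustrative.
image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=28,
    joint_attention_kwargs={"scale": 0.5},  # scales the loaded LoRA layers at inference time
).images[0]
```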
@@ -25,17 +25,13 @@ from transformers import (
)

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
    USE_PEFT_BACKEND,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -350,7 +346,6 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        clip_skip: Optional[int] = None,
        max_sequence_length: int = 256,
        lora_scale: Optional[float] = None,
    ):
        r"""

@@ -396,22 +391,9 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or self._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if self.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder, lora_scale)
            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(self.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
@@ -514,16 +496,6 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
                [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
            )

        if self.text_encoder is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder_2, lora_scale)

        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    def check_inputs(
@@ -12,55 +12,377 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tempfile
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    SD3Transformer2DModel,
    StableDiffusion3Pipeline,
)
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device


if is_peft_available():
    pass
    from peft import LoraConfig
    from peft.utils import get_peft_model_state_dict


sys.path.append(".")

from utils import PeftLoraLoaderMixinTests  # noqa: E402
from utils import check_if_lora_correctly_set  # noqa: E402

@require_peft_backend
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class SD3LoRATests(unittest.TestCase):
    pipeline_class = StableDiffusion3Pipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler()
    scheduler_kwargs = {}
    transformer_kwargs = {
        "sample_size": 32,
        "patch_size": 1,
        "in_channels": 4,
        "num_layers": 1,
        "attention_head_dim": 8,
        "num_attention_heads": 4,
        "caption_projection_dim": 32,
        "joint_attention_dim": 32,
        "pooled_projection_dim": 64,
        "out_channels": 4,
    }
    vae_kwargs = {
        "sample_size": 32,
        "in_channels": 3,
        "out_channels": 3,
        "block_out_channels": (4,),
        "layers_per_block": 1,
        "latent_channels": 4,
        "norm_num_groups": 1,
        "use_quant_conv": False,
        "use_post_quant_conv": False,
        "shift_factor": 0.0609,
        "scaling_factor": 1.5035,
    }
    has_three_text_encoders = True

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = SD3Transformer2DModel(
            sample_size=32,
            patch_size=1,
            in_channels=4,
            num_layers=1,
            attention_head_dim=8,
            num_attention_heads=4,
            caption_projection_dim=32,
            joint_attention_dim=32,
            pooled_projection_dim=64,
            out_channels=4,
        )
        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )

        torch.manual_seed(0)
        text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)

        torch.manual_seed(0)
        text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)

        text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=4,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "text_encoder_3": text_encoder_3,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "tokenizer_3": tokenizer_3,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "output_type": "np",
        }
        return inputs

    def get_lora_config_for_transformer(self):
        lora_config = LoraConfig(
            r=4,
            lora_alpha=4,
            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
            init_lora_weights=False,
            use_dora=False,
        )
        return lora_config

    def get_lora_config_for_text_encoders(self):
        text_lora_config = LoraConfig(
            r=4,
            lora_alpha=4,
            init_lora_weights="gaussian",
            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
        )
        return text_lora_config
    def test_simple_inference_with_transformer_lora_save_load(self):
        components = self.get_dummy_components()
        transformer_config = self.get_lora_config_for_transformer()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(torch_device)

        pipe.transformer.add_adapter(transformer_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        inputs = self.get_dummy_inputs(torch_device)
        images_lora = pipe(**inputs).images

        with tempfile.TemporaryDirectory() as tmpdirname:
            transformer_state_dict = get_peft_model_state_dict(pipe.transformer)

            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname,
                transformer_lora_layers=transformer_state_dict,
            )

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe.unload_lora_weights()

            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

        inputs = self.get_dummy_inputs(torch_device)
        images_lora_from_pretrained = pipe(**inputs).images
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        self.assertTrue(
            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results.",
        )

    def test_simple_inference_with_clip_encoders_lora_save_load(self):
        components = self.get_dummy_components()
        transformer_config = self.get_lora_config_for_transformer()
        text_encoder_config = self.get_lora_config_for_text_encoders()

        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs(torch_device)

        pipe.transformer.add_adapter(transformer_config)
        pipe.text_encoder.add_adapter(text_encoder_config)
        pipe.text_encoder_2.add_adapter(text_encoder_config)

        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2.")

        inputs = self.get_dummy_inputs(torch_device)
        images_lora = pipe(**inputs).images

        with tempfile.TemporaryDirectory() as tmpdirname:
            transformer_state_dict = get_peft_model_state_dict(pipe.transformer)
            text_encoder_one_state_dict = get_peft_model_state_dict(pipe.text_encoder)
            text_encoder_two_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)

            self.pipeline_class.save_lora_weights(
                save_directory=tmpdirname,
                transformer_lora_layers=transformer_state_dict,
                text_encoder_lora_layers=text_encoder_one_state_dict,
                text_encoder_2_lora_layers=text_encoder_two_state_dict,
            )

            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            pipe.unload_lora_weights()

            pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

        inputs = self.get_dummy_inputs(torch_device)
        images_lora_from_pretrained = pipe(**inputs).images
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two")

        self.assertTrue(
            np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results.",
        )
    def test_simple_inference_with_transformer_lora_and_scale(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_lora = pipe(**inputs).images
        self.assertTrue(
            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
        )

        inputs = self.get_dummy_inputs(torch_device)
        output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images
        self.assertTrue(
            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )

        inputs = self.get_dummy_inputs(torch_device)
        output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images
        self.assertTrue(
            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
            "Lora + 0 scale should lead to same result as no LoRA",
        )

    def test_simple_inference_with_clip_encoders_lora_and_scale(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        text_encoder_config = self.get_lora_config_for_text_encoders()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        pipe.text_encoder.add_adapter(text_encoder_config)
        pipe.text_encoder_2.add_adapter(text_encoder_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two")

        inputs = self.get_dummy_inputs(torch_device)
        output_lora = pipe(**inputs).images
        self.assertTrue(
            not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
        )

        inputs = self.get_dummy_inputs(torch_device)
        output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images
        self.assertTrue(
            not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )

        inputs = self.get_dummy_inputs(torch_device)
        output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images
        self.assertTrue(
            np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
            "Lora + 0 scale should lead to same result as no LoRA",
        )
    def test_simple_inference_with_transformer_fused(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

    def test_simple_inference_with_transformer_fused_with_no_fusion(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        inputs = self.get_dummy_inputs(torch_device)
        output_lora = pipe(**inputs).images

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )
        self.assertTrue(
            np.allclose(output_fused, output_lora, atol=1e-3, rtol=1e-3),
            "Fused lora output should match the non-fused LoRA output.",
        )

    def test_simple_inference_with_transformer_fuse_unfuse(self):
        components = self.get_dummy_components()
        transformer_lora_config = self.get_lora_config_for_transformer()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(torch_device)
        output_no_lora = pipe(**inputs).images

        pipe.transformer.add_adapter(transformer_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")

        pipe.fuse_lora()
        # Fusing should still keep the LoRA layers
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        inputs = self.get_dummy_inputs(torch_device)
        output_fused = pipe(**inputs).images
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )

        pipe.unfuse_lora()
        self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
        inputs = self.get_dummy_inputs(torch_device)
        output_unfused_lora = pipe(**inputs).images
        self.assertTrue(
            np.allclose(output_fused, output_unfused_lora, atol=1e-3, rtol=1e-3),
            "Unfused LoRA output should match the fused LoRA output.",
        )

    @require_torch_gpu
    def test_sd3_lora(self):
File diff suppressed because it is too large