up
@@ -37,6 +37,7 @@ from ..utils import (
    is_accelerate_available,
    is_peft_available,
    is_peft_version,
    is_torch_version,
    is_transformers_available,
    is_transformers_version,
    logging,
@@ -49,6 +50,17 @@ from ..utils.peft_utils import _create_lora_config
from ..utils.state_dict_utils import _load_sft_state_dict_metadata


_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
if is_torch_version(">=", "1.9.0"):
    if (
        is_peft_available()
        and is_peft_version(">=", "0.13.1")
        and is_transformers_available()
        and is_transformers_version(">", "4.45.2")
    ):
        _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True


if is_transformers_available():
    from transformers import PreTrainedModel

@@ -64,6 +76,10 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"

TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"


def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
    """
@@ -18,6 +18,7 @@ from typing import Callable, Dict, List, Optional, Union
import torch
from huggingface_hub.utils import validate_hf_hub_args

from ..pipelines.stable_diffusion.lora_utils import StableDiffusionLoraLoaderMixin
from ..utils import (
    USE_PEFT_BACKEND,
    deprecate,
@@ -127,468 +128,11 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
    return module_weight


class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
    """

    _lora_loadable_modules = ["unet", "text_encoder"]
    unet_name = UNET_NAME
    text_encoder_name = TEXT_ENCODER_NAME

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
        `self.text_encoder`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
        loaded.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
        loaded into `self.unet`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
        dict is loaded into `self.text_encoder`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap (`bool`, *optional*):
                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
                in-place. This means that, instead of loading an additional adapter, this will take the existing
                adapter weights and replace them with the weights of the new adapter. This can be faster and more
                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
                torch.compile, loading the new adapter does not require recompilation of the model. When using
                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.

                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
                to call an additional method before loading the adapter:

                ```py
                pipeline = ...  # load diffusers pipeline
                max_rank = ...  # the highest rank among all LoRAs that you want to load
                # call *before* compiling and loading the LoRA adapter
                pipeline.enable_lora_hotswap(target_rank=max_rank)
                pipeline.load_lora_weights(file_name)
                # optionally compile the model now
                ```

                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                limitations to this technique, which are documented here:
                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        kwargs["return_lora_metadata"] = True
        state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_unet(
            state_dict,
            network_alphas=network_alphas,
            unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )
        self.load_lora_into_text_encoder(
            state_dict,
            network_alphas=network_alphas,
            text_encoder=getattr(self, self.text_encoder_name)
            if not hasattr(self, "text_encoder")
            else self.text_encoder,
            lora_scale=self.lora_scale,
            adapter_name=adapter_name,
            _pipeline=self,
            metadata=metadata,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )
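
For orientation, a minimal usage sketch of the method above. The model and LoRA repository ids are placeholders chosen for illustration, not references to specific checkpoints:

```py
import torch

from diffusers import StableDiffusionPipeline

# Hypothetical repo ids, shown only to illustrate the call sequence.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-sd15-lora", adapter_name="style")
image = pipe("a castle in pixel art style").images[0]
```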

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return the state dict for the LoRA weights and the network alphas.

        > [!WARNING]
        > We support loading A1111 formatted LoRA checkpoints in a limited capacity.
        >
        > This function is experimental and might change in the future.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.
            return_lora_metadata (`bool`, *optional*, defaults to False):
                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # UNet and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        unet_config = kwargs.pop("unet_config", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )
        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        network_alphas = None
        # TODO: replace it with a method from `state_dict_utils`
        if all(
            (
                k.startswith("lora_te_")
                or k.startswith("lora_unet_")
                or k.startswith("lora_te1_")
                or k.startswith("lora_te2_")
            )
            for k in state_dict.keys()
        ):
            # Map SDXL blocks correctly.
            if unet_config is not None:
                # use unet config to remap block numbers
                state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
            state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

        out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
        return out
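
Because `lora_state_dict` is a classmethod, a checkpoint can be inspected without touching any model weights. A hedged sketch, with a placeholder repo id:

```py
from diffusers import StableDiffusionPipeline

# "some-user/some-sd15-lora" is a placeholder, not a real checkpoint.
state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(
    "some-user/some-sd15-lora", weight_name="pytorch_lora_weights.safetensors"
)
print(f"{len(state_dict)} LoRA tensors, network_alphas={network_alphas}")
```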

    @classmethod
    def load_lora_into_unet(
        cls,
        state_dict,
        network_alphas,
        unet,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet`, which can be used to distinguish them from the
                text encoder lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            unet (`UNet2DConditionModel`):
                The UNet model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap (`bool`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
            metadata (`dict`):
                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
                from the state dict.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
        # their prefixes.
        logger.info(f"Loading {cls.unet_name}.")
        unet.load_lora_adapter(
            state_dict,
            prefix=cls.unet_name,
            network_alphas=network_alphas,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys should be prefixed with an
                additional `text_encoder` to distinguish them from the unet lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
                Expected prefix of the `text_encoder` in the `state_dict`.
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added to the output of the regular
                lora layer.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap (`bool`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
            metadata (`dict`):
                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
                from the state dict.
        """
        _load_lora_into_text_encoder(
            state_dict=state_dict,
            network_alphas=network_alphas,
            lora_scale=lora_scale,
            text_encoder=text_encoder,
            prefix=prefix,
            text_encoder_name=cls.text_encoder_name,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        unet_lora_adapter_metadata=None,
        text_encoder_lora_adapter_metadata=None,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `unet`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training when
                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
                main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
            unet_lora_adapter_metadata:
                LoRA adapter metadata associated with the unet to be serialized with the state dict.
            text_encoder_lora_adapter_metadata:
                LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
        """
        lora_layers = {}
        lora_metadata = {}

        if unet_lora_layers:
            lora_layers[cls.unet_name] = unet_lora_layers
            lora_metadata[cls.unet_name] = unet_lora_adapter_metadata

        if text_encoder_lora_layers:
            lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
            lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata

        if not lora_layers:
            raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )
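
A hedged sketch of how trained LoRA layers might be persisted with this classmethod, in the style of the diffusers training scripts; `unet` is assumed to already carry a trained PEFT LoRA adapter, and the save directory is a placeholder:

```py
from peft.utils import get_peft_model_state_dict

from diffusers import StableDiffusionPipeline
from diffusers.utils import convert_state_dict_to_diffusers

# `unet` is assumed to be a UNet2DConditionModel with a trained PEFT LoRA adapter.
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
StableDiffusionPipeline.save_lora_weights(
    save_directory="./my-sd15-lora",
    unet_lora_layers=unet_lora_state_dict,
    safe_serialization=True,
)
```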

    def fuse_lora(
        self,
        components: List[str] = ["unet", "text_encoder"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        > [!WARNING]
        > This is an experimental API.

        Args:
            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing, and to skip fusing weights that contain
                NaN values.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        > [!WARNING]
        > This is an experimental API.

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
            unfuse_text_encoder (`bool`, defaults to `True`):
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
        super().unfuse_lora(components=components, **kwargs)
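
Fusing is typically a round trip: fuse for inference at a fixed scale, then unfuse to restore the original weights before switching adapters. A brief hedged sketch continuing the docstring example above:

```py
pipeline.fuse_lora(lora_scale=0.7)  # bake the LoRA into the base weights
image = pipeline("pixel art castle").images[0]
pipeline.unfuse_lora()  # restore the original, unfused weights
```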


class StableDiffusionLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    def __init__(self, *args, **kwargs):
        deprecation_message = "Importing `StableDiffusionLoraLoaderMixin` from `diffusers.loaders` is deprecated and this will be removed in a future version. Please use `from diffusers.pipelines.stable_diffusion.lora_utils import StableDiffusionLoraLoaderMixin`, instead."
        deprecate("StableDiffusionLoraLoaderMixin", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
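
The shim above works because, at class-creation time, the base-class name still resolves to the `StableDiffusionLoraLoaderMixin` imported from `diffusers.pipelines.stable_diffusion.lora_utils`; the new class then shadows that name in `diffusers.loaders` and emits a deprecation warning when instantiated. A hedged illustration of what callers see:

```py
# Old import path: still works, but any class instantiating the shim
# triggers the deprecation warning above.
from diffusers.loaders import StableDiffusionLoraLoaderMixin  # deprecated

# New canonical path after this commit:
from diffusers.pipelines.stable_diffusion.lora_utils import StableDiffusionLoraLoaderMixin
```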


class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):


src/diffusers/pipelines/stable_diffusion/lora_utils.py (new file, 501 lines)
@@ -0,0 +1,501 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Callable, Dict, List, Optional, Union

import torch

from ...loaders.lora_base import (
    _LOW_CPU_MEM_USAGE_DEFAULT_LORA,
    LoraBaseMixin,
    _fetch_state_dict,
    _load_lora_into_text_encoder,
)
from ...loaders.lora_conversion_utils import (
    _convert_non_diffusers_lora_to_diffusers,
    _maybe_map_sgm_blocks_to_diffusers,
)
from ...utils import USE_PEFT_BACKEND, is_peft_version, logging
from ...utils.hub_utils import validate_hf_hub_args


logger = logging.get_logger(__name__)

TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"


class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
    """

    _lora_loadable_modules = ["unet", "text_encoder"]
    unet_name = UNET_NAME
    text_encoder_name = TEXT_ENCODER_NAME

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
        `self.text_encoder`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
        loaded.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
        loaded into `self.unet`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
        dict is loaded into `self.text_encoder`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap (`bool`, *optional*):
                Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
                in-place. This means that, instead of loading an additional adapter, this will take the existing
                adapter weights and replace them with the weights of the new adapter. This can be faster and more
                memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
                torch.compile, loading the new adapter does not require recompilation of the model. When using
                hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.

                If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
                to call an additional method before loading the adapter:

                ```py
                pipeline = ...  # load diffusers pipeline
                max_rank = ...  # the highest rank among all LoRAs that you want to load
                # call *before* compiling and loading the LoRA adapter
                pipeline.enable_lora_hotswap(target_rank=max_rank)
                pipeline.load_lora_weights(file_name)
                # optionally compile the model now
                ```

                Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                limitations to this technique, which are documented here:
                https://huggingface.co/docs/peft/main/en/package_reference/hotswap
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        kwargs["return_lora_metadata"] = True
        state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_unet(
            state_dict,
            network_alphas=network_alphas,
            unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )
        self.load_lora_into_text_encoder(
            state_dict,
            network_alphas=network_alphas,
            text_encoder=getattr(self, self.text_encoder_name)
            if not hasattr(self, "text_encoder")
            else self.text_encoder,
            lora_scale=self.lora_scale,
            adapter_name=adapter_name,
            _pipeline=self,
            metadata=metadata,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )
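
To make the hotswap path described in the docstring concrete, a hedged end-to-end sketch; `pipe` is assumed to be an already constructed Stable Diffusion pipeline, the target rank is arbitrary, and the LoRA repo ids are placeholders:

```py
import torch

# Prepare once, *before* compiling, if adapters may differ in rank/alpha.
pipe.enable_lora_hotswap(target_rank=64)
pipe.load_lora_weights("some-user/lora-a", adapter_name="default_0")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead")

# Later: swap in a different adapter under the same name, without recompilation.
pipe.load_lora_weights("some-user/lora-b", adapter_name="default_0", hotswap=True)
```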

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return the state dict for the LoRA weights and the network alphas.

        > [!WARNING]
        > We support loading A1111 formatted LoRA checkpoints in a limited capacity.
        >
        > This function is experimental and might change in the future.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.
            return_lora_metadata (`bool`, *optional*, defaults to False):
                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
        """
        # Load the main state dict first which has the LoRA layers for either of
        # UNet and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        unet_config = kwargs.pop("unet_config", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )
        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated with `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        network_alphas = None
        # TODO: replace it with a method from `state_dict_utils`
        if all(
            (
                k.startswith("lora_te_")
                or k.startswith("lora_unet_")
                or k.startswith("lora_te1_")
                or k.startswith("lora_te2_")
            )
            for k in state_dict.keys()
        ):
            # Map SDXL blocks correctly.
            if unet_config is not None:
                # use unet config to remap block numbers
                state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
            state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

        out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
        return out
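
When the adapter was saved with metadata, `return_lora_metadata=True` surfaces it alongside the tensors. A small hedged sketch with a placeholder repo id:

```py
from diffusers import StableDiffusionPipeline

state_dict, network_alphas, metadata = StableDiffusionPipeline.lora_state_dict(
    "some-user/some-sd15-lora", return_lora_metadata=True
)
# `metadata`, when present, carries the original `LoraConfig` arguments,
# so they don't have to be re-derived from the state dict keys.
```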

    @classmethod
    def load_lora_into_unet(
        cls,
        state_dict,
        network_alphas,
        unet,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet`, which can be used to distinguish them from the
                text encoder lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            unet (`UNet2DConditionModel`):
                The UNet model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap (`bool`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
            metadata (`dict`):
                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
                from the state dict.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
        # their prefixes.
        logger.info(f"Loading {cls.unet_name}.")
        unet.load_lora_adapter(
            state_dict,
            prefix=cls.unet_name,
            network_alphas=network_alphas,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )
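
Since the heavy lifting is delegated to `unet.load_lora_adapter`, the UNet part of a checkpoint can also be loaded without a pipeline. A hedged sketch; the repo ids are placeholders, and `load_lora_adapter` is assumed to be available on the model through diffusers' PEFT integration:

```py
from diffusers import UNet2DConditionModel

# Placeholder repo ids, for illustration only.
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet"
)
unet.load_lora_adapter("some-user/some-sd15-lora", prefix="unet")
```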

    @classmethod
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The keys should be prefixed with an
                additional `text_encoder` to distinguish them from the unet lora layers.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
                Expected prefix of the `text_encoder` in the `state_dict`.
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added to the output of the regular
                lora layer.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap (`bool`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
            metadata (`dict`):
                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
                from the state dict.
        """
        _load_lora_into_text_encoder(
            state_dict=state_dict,
            network_alphas=network_alphas,
            lora_scale=lora_scale,
            text_encoder=text_encoder,
            prefix=prefix,
            text_encoder_name=cls.text_encoder_name,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        unet_lora_adapter_metadata=None,
        text_encoder_lora_adapter_metadata=None,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `unet`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training when
                you need to call this function on all processes. In this case, set `is_main_process=True` only on the
                main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
            unet_lora_adapter_metadata:
                LoRA adapter metadata associated with the unet to be serialized with the state dict.
            text_encoder_lora_adapter_metadata:
                LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
        """
        lora_layers = {}
        lora_metadata = {}

        if unet_lora_layers:
            lora_layers[cls.unet_name] = unet_lora_layers
            lora_metadata[cls.unet_name] = unet_lora_adapter_metadata

        if text_encoder_lora_layers:
            lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
            lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata

        if not lora_layers:
            raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")

        cls._save_lora_weights(
            save_directory=save_directory,
            lora_layers=lora_layers,
            lora_metadata=lora_metadata,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def fuse_lora(
        self,
        components: List[str] = ["unet", "text_encoder"],
        lora_scale: float = 1.0,
        safe_fusing: bool = False,
        adapter_names: Optional[List[str]] = None,
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        > [!WARNING]
        > This is an experimental API.

        Args:
            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing, and to skip fusing weights that contain
                NaN values.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.

        Example:

        ```py
        from diffusers import DiffusionPipeline
        import torch

        pipeline = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
        pipeline.fuse_lora(lora_scale=0.7)
        ```
        """
        super().fuse_lora(
            components=components,
            lora_scale=lora_scale,
            safe_fusing=safe_fusing,
            adapter_names=adapter_names,
            **kwargs,
        )

    def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs):
        r"""
        Reverses the effect of
        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).

        > [!WARNING]
        > This is an experimental API.

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
            unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
            unfuse_text_encoder (`bool`, defaults to `True`):
                Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                LoRA parameters then it won't have any effect.
        """
        super().unfuse_lora(components=components, **kwargs)
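
When several adapters are loaded, which ones are active (and at what weight) can be selected through the pipeline's adapter API before fusing. A brief hedged sketch reusing the adapter name from the docstring example:

```py
pipeline.set_adapters(["pixel"], adapter_weights=[0.8])
pipeline.fuse_lora(adapter_names=["pixel"], lora_scale=0.8)
```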

@@ -20,12 +20,13 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from .lora_utils import StableDiffusionLoraLoaderMixin
 from .pipeline_output import StableDiffusionPipelineOutput
 from .pipeline_stable_diffusion_utils import SDMixin, rescale_noise_cfg, retrieve_timesteps
 from .safety_checker import StableDiffusionSafetyChecker

@@ -23,7 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation, DP

 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -34,6 +34,7 @@ from ...utils import (
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .lora_utils import StableDiffusionLoraLoaderMixin
 from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents

@@ -23,19 +23,14 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import (
-    PIL_INTERPOLATION,
-    deprecate,
-    is_torch_xla_available,
-    logging,
-    replace_example_docstring,
-)
+from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
+from .lora_utils import StableDiffusionLoraLoaderMixin
 from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents, retrieve_timesteps
 from .safety_checker import StableDiffusionSafetyChecker

@@ -22,13 +22,14 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...configuration_utils import FrozenDict
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, TextualInversionLoaderMixin
 from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
+from .lora_utils import StableDiffusionLoraLoaderMixin
 from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents, retrieve_timesteps
 from .safety_checker import StableDiffusionSafetyChecker

@@ -22,13 +22,14 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
+from .lora_utils import StableDiffusionLoraLoaderMixin
 from .safety_checker import StableDiffusionSafetyChecker

@@ -21,13 +21,14 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
 from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
+from .lora_utils import StableDiffusionLoraLoaderMixin
 from .pipeline_stable_diffusion_utils import SDMixin

@@ -20,13 +20,14 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput

 from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
+from .lora_utils import StableDiffusionLoraLoaderMixin
 from .pipeline_stable_diffusion_utils import SDMixin
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer

@@ -19,13 +19,14 @@ import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

 from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
+from .lora_utils import StableDiffusionLoraLoaderMixin
 from .pipeline_stable_diffusion_utils import SDMixin
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer