From 044392d65fc916c2c3cfac7a992af9f319321e8a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 8 Dec 2025 13:36:29 +0530 Subject: [PATCH] up --- src/diffusers/loaders/lora_base.py | 16 + src/diffusers/loaders/lora_pipeline.py | 468 +--------------- .../pipelines/stable_diffusion/lora_utils.py | 501 ++++++++++++++++++ .../pipeline_stable_diffusion.py | 3 +- .../pipeline_stable_diffusion_depth2img.py | 3 +- .../pipeline_stable_diffusion_img2img.py | 11 +- .../pipeline_stable_diffusion_inpaint.py | 3 +- ...eline_stable_diffusion_instruct_pix2pix.py | 3 +- .../pipeline_stable_diffusion_upscale.py | 3 +- .../pipeline_stable_unclip.py | 3 +- .../pipeline_stable_unclip_img2img.py | 3 +- 11 files changed, 540 insertions(+), 477 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion/lora_utils.py diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 3d75a7d875..7a8a1d1a58 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -37,6 +37,7 @@ from ..utils import ( is_accelerate_available, is_peft_available, is_peft_version, + is_torch_version, is_transformers_available, is_transformers_version, logging, @@ -49,6 +50,17 @@ from ..utils.peft_utils import _create_lora_config from ..utils.state_dict_utils import _load_sft_state_dict_metadata +_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False +if is_torch_version(">=", "1.9.0"): + if ( + is_peft_available() + and is_peft_version(">=", "0.13.1") + and is_transformers_available() + and is_transformers_version(">", "4.45.2") + ): + _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True + + if is_transformers_available(): from transformers import PreTrainedModel @@ -64,6 +76,10 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata" +TEXT_ENCODER_NAME = "text_encoder" +UNET_NAME = "unet" +TRANSFORMER_NAME = "transformer" + def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None): """ diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index bcbe54649f..7f40d9442b 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -18,6 +18,7 @@ from typing import Callable, Dict, List, Optional, Union import torch from huggingface_hub.utils import validate_hf_hub_args +from ..pipelines.stable_diffusion.lora_utils import StableDiffusionLoraLoaderMixin from ..utils import ( USE_PEFT_BACKEND, deprecate, @@ -127,468 +128,11 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module): return module_weight -class StableDiffusionLoraLoaderMixin(LoraBaseMixin): - r""" - Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and - [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). - """ - - _lora_loadable_modules = ["unet", "text_encoder"] - unet_name = UNET_NAME - text_encoder_name = TEXT_ENCODER_NAME - - def load_lora_weights( - self, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - adapter_name: Optional[str] = None, - hotswap: bool = False, - **kwargs, - ): - """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and - `self.text_encoder`. - - All kwargs are forwarded to `self.lora_state_dict`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is - loaded. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is - loaded into `self.unet`. - - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state - dict is loaded into `self.text_encoder`. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter - in-place. This means that, instead of loading an additional adapter, this will take the existing - adapter weights and replace them with the weights of the new adapter. This can be faster and more - memory efficient. However, the main advantage of hotswapping is that when the model is compiled with - torch.compile, loading the new adapter does not require recompilation of the model. When using - hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. - - If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need - to call an additional method before loading the adapter: - - ```py - pipeline = ... # load diffusers pipeline - max_rank = ... # the highest rank among all LoRAs that you want to load - # call *before* compiling and loading the LoRA adapter - pipeline.enable_lora_hotswap(target_rank=max_rank) - pipeline.load_lora_weights(file_name) - # optionally compile the model now - ``` - - Note that hotswapping adapters of the text encoder is not yet supported. There are some further - limitations to this technique, which are documented here: - https://huggingface.co/docs/peft/main/en/package_reference/hotswap - kwargs (`dict`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # if a dict is passed, copy it instead of modifying it inplace - if isinstance(pretrained_model_name_or_path_or_dict, dict): - pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() - - # First, ensure that the checkpoint is a compatible one and can be successfully loaded. - kwargs["return_lora_metadata"] = True - state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - - is_correct_format = all("lora" in key for key in state_dict.keys()) - if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") - - self.load_lora_into_unet( - state_dict, - network_alphas=network_alphas, - unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=self, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - self.load_lora_into_text_encoder( - state_dict, - network_alphas=network_alphas, - text_encoder=getattr(self, self.text_encoder_name) - if not hasattr(self, "text_encoder") - else self.text_encoder, - lora_scale=self.lora_scale, - adapter_name=adapter_name, - _pipeline=self, - metadata=metadata, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - @validate_hf_hub_args - def lora_state_dict( - cls, - pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], - **kwargs, - ): - r""" - Return state dict for lora weights and the network alphas. - - > [!WARNING] > We support loading A1111 formatted LoRA checkpoints in a limited capacity. > > This function is - experimental and might change in the future. - - Parameters: - pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): - Can be either: - - - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on - the Hub. - - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved - with [`ModelMixin.save_pretrained`]. - - A [torch state - dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). - - cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory where a downloaded pretrained model configuration is cached if the standard cache - is not used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - local_files_only (`bool`, *optional*, defaults to `False`): - Whether to only load local model weights and configuration files or not. If set to `True`, the model - won't be downloaded from the Hub. - token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from - `diffusers-cli login` (stored in `~/.huggingface`) is used. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier - allowed by Git. - subfolder (`str`, *optional*, defaults to `""`): - The subfolder location of a model file within a larger model repository on the Hub or locally. - weight_name (`str`, *optional*, defaults to None): - Name of the serialized state dict file. - return_lora_metadata (`bool`, *optional*, defaults to False): - When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. - """ - # Load the main state dict first which has the LoRA layers for either of - # UNet and text encoder or both. - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - weight_name = kwargs.pop("weight_name", None) - unet_config = kwargs.pop("unet_config", None) - use_safetensors = kwargs.pop("use_safetensors", None) - return_lora_metadata = kwargs.pop("return_lora_metadata", False) - - allow_pickle = False - if use_safetensors is None: - use_safetensors = True - allow_pickle = True - - user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} - - state_dict, metadata = _fetch_state_dict( - pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, - weight_name=weight_name, - use_safetensors=use_safetensors, - local_files_only=local_files_only, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - allow_pickle=allow_pickle, - ) - is_dora_scale_present = any("dora_scale" in k for k in state_dict) - if is_dora_scale_present: - warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." - logger.warning(warn_msg) - state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} - - network_alphas = None - # TODO: replace it with a method from `state_dict_utils` - if all( - ( - k.startswith("lora_te_") - or k.startswith("lora_unet_") - or k.startswith("lora_te1_") - or k.startswith("lora_te2_") - ) - for k in state_dict.keys() - ): - # Map SDXL blocks correctly. - if unet_config is not None: - # use unet config to remap block numbers - state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) - state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) - - out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas) - return out - - @classmethod - def load_lora_into_unet( - cls, - state_dict, - network_alphas, - unet, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, - ): - """ - This will load the LoRA layers specified in `state_dict` into `unet`. - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The keys can either be indexed directly - into the unet or prefixed with an additional `unet` which can be used to distinguish between text - encoder lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - unet (`UNet2DConditionModel`): - The UNet model to load the LoRA layers into. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. - """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as - # their prefixes. - logger.info(f"Loading {cls.unet_name}.") - unet.load_lora_adapter( - state_dict, - prefix=cls.unet_name, - network_alphas=network_alphas, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def load_lora_into_text_encoder( - cls, - state_dict, - network_alphas, - text_encoder, - prefix=None, - lora_scale=1.0, - adapter_name=None, - _pipeline=None, - low_cpu_mem_usage=False, - hotswap: bool = False, - metadata=None, - ): - """ - This will load the LoRA layers specified in `state_dict` into `text_encoder` - - Parameters: - state_dict (`dict`): - A standard state dict containing the lora layer parameters. The key should be prefixed with an - additional `text_encoder` to distinguish between unet lora layers. - network_alphas (`Dict[str, float]`): - The value of the network alpha used for stable learning and preventing underflow. This value has the - same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this - link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). - text_encoder (`CLIPTextModel`): - The text encoder model to load the LoRA layers into. - prefix (`str`): - Expected prefix of the `text_encoder` in the `state_dict`. - lora_scale (`float`): - How much to scale the output of the lora linear layer before it is added with the output of the regular - lora layer. - adapter_name (`str`, *optional*): - Adapter name to be used for referencing the loaded adapter model. If not specified, it will use - `default_{i}` where i is the total number of adapters being loaded. - low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random - weights. - hotswap (`bool`, *optional*): - See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. - metadata (`dict`): - Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived - from the state dict. - """ - _load_lora_into_text_encoder( - state_dict=state_dict, - network_alphas=network_alphas, - lora_scale=lora_scale, - text_encoder=text_encoder, - prefix=prefix, - text_encoder_name=cls.text_encoder_name, - adapter_name=adapter_name, - metadata=metadata, - _pipeline=_pipeline, - low_cpu_mem_usage=low_cpu_mem_usage, - hotswap=hotswap, - ) - - @classmethod - def save_lora_weights( - cls, - save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - unet_lora_adapter_metadata=None, - text_encoder_lora_adapter_metadata=None, - ): - r""" - Save the LoRA parameters corresponding to the UNet and text encoder. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to save LoRA parameters to. Will be created if it doesn't exist. - unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `unet`. - text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): - State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text - encoder LoRA state dict because it comes from 🤗 Transformers. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful during distributed training and you - need to call this function on all processes. In this case, set `is_main_process=True` only on the main - process to avoid race conditions. - save_function (`Callable`): - The function to use to save the state dictionary. Useful during distributed training when you need to - replace `torch.save` with another method. Can be configured with the environment variable - `DIFFUSERS_SAVE_MODE`. - safe_serialization (`bool`, *optional*, defaults to `True`): - Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. - unet_lora_adapter_metadata: - LoRA adapter metadata associated with the unet to be serialized with the state dict. - text_encoder_lora_adapter_metadata: - LoRA adapter metadata associated with the text encoder to be serialized with the state dict. - """ - lora_layers = {} - lora_metadata = {} - - if unet_lora_layers: - lora_layers[cls.unet_name] = unet_lora_layers - lora_metadata[cls.unet_name] = unet_lora_adapter_metadata - - if text_encoder_lora_layers: - lora_layers[cls.text_encoder_name] = text_encoder_lora_layers - lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata - - if not lora_layers: - raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.") - - cls._save_lora_weights( - save_directory=save_directory, - lora_layers=lora_layers, - lora_metadata=lora_metadata, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - def fuse_lora( - self, - components: List[str] = ["unet", "text_encoder"], - lora_scale: float = 1.0, - safe_fusing: bool = False, - adapter_names: Optional[List[str]] = None, - **kwargs, - ): - r""" - Fuses the LoRA parameters into the original parameters of the corresponding blocks. - - > [!WARNING] > This is an experimental API. - - Args: - components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. - lora_scale (`float`, defaults to 1.0): - Controls how much to influence the outputs with the LoRA parameters. - safe_fusing (`bool`, defaults to `False`): - Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. - adapter_names (`List[str]`, *optional*): - Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. - - Example: - - ```py - from diffusers import DiffusionPipeline - import torch - - pipeline = DiffusionPipeline.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") - pipeline.fuse_lora(lora_scale=0.7) - ``` - """ - super().fuse_lora( - components=components, - lora_scale=lora_scale, - safe_fusing=safe_fusing, - adapter_names=adapter_names, - **kwargs, - ) - - def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs): - r""" - Reverses the effect of - [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). - - > [!WARNING] > This is an experimental API. - - Args: - components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. - unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. - """ - super().unfuse_lora(components=components, **kwargs) +class StableDiffusionLoraLoaderMixin(StableDiffusionLoraLoaderMixin): + def __init__(self, *args, **kwargs): + deprecation_message = "Import `StableDiffusionLoraLoaderMixin` from `diffusers.loaders` is deprecated and this will be removed in a future version. Please use `from diffusers.pipelines.stable_diffusion.lora_utils import StableDiffusionLoraLoaderMixin`, instead." + deprecate("StableDiffusionLoraLoaderMixin", "1.0.0", deprecation_message) + super().__init__(*args, **kwargs) class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin): diff --git a/src/diffusers/pipelines/stable_diffusion/lora_utils.py b/src/diffusers/pipelines/stable_diffusion/lora_utils.py new file mode 100644 index 0000000000..eda3a352ac --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/lora_utils.py @@ -0,0 +1,501 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, Dict, List, Optional, Union + +import torch + +from ...loaders.lora_base import ( + _LOW_CPU_MEM_USAGE_DEFAULT_LORA, + LoraBaseMixin, + _fetch_state_dict, + _load_lora_into_text_encoder, +) +from ...loaders.lora_conversion_utils import ( + _convert_non_diffusers_lora_to_diffusers, + _maybe_map_sgm_blocks_to_diffusers, +) +from ...utils import USE_PEFT_BACKEND, is_peft_version, logging +from ...utils.hub_utils import validate_hf_hub_args + + +logger = logging.get_logger(__name__) + +TEXT_ENCODER_NAME = "text_encoder" +UNET_NAME = "unet" + + +class StableDiffusionLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into Stable Diffusion [`UNet2DConditionModel`] and + [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). + """ + + _lora_loadable_modules = ["unet", "text_encoder"] + unet_name = UNET_NAME + text_encoder_name = TEXT_ENCODER_NAME + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is + loaded. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is + loaded into `self.unet`. + + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state + dict is loaded into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter + in-place. This means that, instead of loading an additional adapter, this will take the existing + adapter weights and replace them with the weights of the new adapter. This can be faster and more + memory efficient. However, the main advantage of hotswapping is that when the model is compiled with + torch.compile, loading the new adapter does not require recompilation of the model. When using + hotswapping, the passed `adapter_name` should be the name of an already loaded adapter. + + If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need + to call an additional method before loading the adapter: + + ```py + pipeline = ... # load diffusers pipeline + max_rank = ... # the highest rank among all LoRAs that you want to load + # call *before* compiling and loading the LoRA adapter + pipeline.enable_lora_hotswap(target_rank=max_rank) + pipeline.load_lora_weights(file_name) + # optionally compile the model now + ``` + + Note that hotswapping adapters of the text encoder is not yet supported. There are some further + limitations to this technique, which are documented here: + https://huggingface.co/docs/peft/main/en/package_reference/hotswap + kwargs (`dict`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + self.load_lora_into_unet( + state_dict, + network_alphas=network_alphas, + unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + self.load_lora_into_text_encoder( + state_dict, + network_alphas=network_alphas, + text_encoder=getattr(self, self.text_encoder_name) + if not hasattr(self, "text_encoder") + else self.text_encoder, + lora_scale=self.lora_scale, + adapter_name=adapter_name, + _pipeline=self, + metadata=metadata, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + > [!WARNING] > We support loading A1111 formatted LoRA checkpoints in a limited capacity. > > This function is + experimental and might change in the future. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository on the Hub or locally. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. + """ + # Load the main state dict first which has the LoRA layers for either of + # UNet and text encoder or both. + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + network_alphas = None + # TODO: replace it with a method from `state_dict_utils` + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. + if unet_config is not None: + # use unet config to remap block numbers + state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict) + + out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas) + return out + + @classmethod + def load_lora_into_unet( + cls, + state_dict, + network_alphas, + unet, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `unet`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as + # their prefixes. + logger.info(f"Loading {cls.unet_name}.") + unet.load_lora_adapter( + state_dict, + prefix=cls.unet_name, + network_alphas=network_alphas, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_text_encoder( + cls, + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + This will load the LoRA layers specified in `state_dict` into `text_encoder` + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The key should be prefixed with an + additional `text_encoder` to distinguish between unet lora layers. + network_alphas (`Dict[str, float]`): + The value of the network alpha used for stable learning and preventing underflow. This value has the + same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this + link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning). + text_encoder (`CLIPTextModel`): + The text encoder model to load the LoRA layers into. + prefix (`str`): + Expected prefix of the `text_encoder` in the `state_dict`. + lora_scale (`float`): + How much to scale the output of the lora linear layer before it is added with the output of the regular + lora layer. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + hotswap (`bool`, *optional*): + See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived + from the state dict. + """ + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + unet_lora_adapter_metadata=None, + text_encoder_lora_adapter_metadata=None, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `unet`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + unet_lora_adapter_metadata: + LoRA adapter metadata associated with the unet to be serialized with the state dict. + text_encoder_lora_adapter_metadata: + LoRA adapter metadata associated with the text encoder to be serialized with the state dict. + """ + lora_layers = {} + lora_metadata = {} + + if unet_lora_layers: + lora_layers[cls.unet_name] = unet_lora_layers + lora_metadata[cls.unet_name] = unet_lora_adapter_metadata + + if text_encoder_lora_layers: + lora_layers[cls.text_encoder_name] = text_encoder_lora_layers + lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["unet", "text_encoder"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + > [!WARNING] > This is an experimental API. + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused. + + Example: + + ```py + from diffusers import DiffusionPipeline + import torch + + pipeline = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ).to("cuda") + pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: List[str] = ["unet", "text_encoder"], **kwargs): + r""" + Reverses the effect of + [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). + + > [!WARNING] > This is an experimental API. + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. + unfuse_text_encoder (`bool`, defaults to `True`): + Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the + LoRA parameters then it won't have any effect. + """ + super().unfuse_lora(components=components, **kwargs) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index e2745005b2..62a70583f0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -20,12 +20,13 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .lora_utils import StableDiffusionLoraLoaderMixin from .pipeline_output import StableDiffusionPipelineOutput from .pipeline_stable_diffusion_utils import SDMixin, rescale_noise_cfg, retrieve_timesteps from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 976c51be55..cfaf03bdc8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -23,7 +23,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTForDepthEstimation, DP from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -34,6 +34,7 @@ from ...utils import ( ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from .lora_utils import StableDiffusionLoraLoaderMixin from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index b6aa1c2522..ac7aca3a06 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -23,19 +23,14 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - PIL_INTERPOLATION, - deprecate, - is_torch_xla_available, - logging, - replace_example_docstring, -) +from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import StableDiffusionPipelineOutput +from .lora_utils import StableDiffusionLoraLoaderMixin from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents, retrieve_timesteps from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 3595f837f9..ad1bd5e307 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -22,13 +22,14 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...configuration_utils import FrozenDict from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, TextualInversionLoaderMixin from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import StableDiffusionPipelineOutput +from .lora_utils import StableDiffusionLoraLoaderMixin from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents, retrieve_timesteps from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index a6c98e9f25..80f59d6e1d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -22,13 +22,14 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, IPAdapterMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import StableDiffusionPipelineOutput +from .lora_utils import StableDiffusionLoraLoaderMixin from .safety_checker import StableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index aa58b8534c..86f1f213ad 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -21,13 +21,14 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import StableDiffusionPipelineOutput +from .lora_utils import StableDiffusionLoraLoaderMixin from .pipeline_stable_diffusion_utils import SDMixin diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index c9f88e9d57..a964789667 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -20,13 +20,14 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz from transformers.models.clip.modeling_clip import CLIPTextModelOutput from ...image_processor import VaeImageProcessor -from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin +from .lora_utils import StableDiffusionLoraLoaderMixin from .pipeline_stable_diffusion_utils import SDMixin from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index f6b6a2b74f..372edd54a2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -19,13 +19,14 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import VaeImageProcessor -from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin +from .lora_utils import StableDiffusionLoraLoaderMixin from .pipeline_stable_diffusion_utils import SDMixin from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer