diff --git a/docs/source/en/api/loaders.md b/docs/source/en/api/loaders.md index 98aaea0060..5c7c3ef660 100644 --- a/docs/source/en/api/loaders.md +++ b/docs/source/en/api/loaders.md @@ -28,6 +28,10 @@ Adapters (textual inversion, LoRA, hypernetworks) allow you to modify a diffusio [[autodoc]] loaders.TextualInversionLoaderMixin +## StableDiffusionXLLoraLoaderMixin + +[[autodoc]] loaders.StableDiffusionXLLoraLoaderMixin + ## LoraLoaderMixin [[autodoc]] loaders.LoraLoaderMixin diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 16eabb0077..51814a611a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -33,6 +33,7 @@ from .utils import ( _get_model_file, deprecate, is_accelerate_available, + is_accelerate_version, is_omegaconf_available, is_transformers_available, logging, @@ -2556,3 +2557,151 @@ class FromOriginalControlnetMixin: controlnet.to(torch_dtype=torch_dtype) return controlnet + + +class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): + """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL""" + + # Overrride to properly handle the loading and unloading of the additional text encoder. + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and + `self.text_encoder`. + + All kwargs are forwarded to `self.lora_state_dict`. + + See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. + + See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into + `self.unet`. + + See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded + into `self.text_encoder`. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + kwargs (`dict`, *optional*): + See [`~loaders.LoraLoaderMixin.lora_state_dict`]. + """ + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + + # Remove any existing hooks. + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module + else: + raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") + + is_model_cpu_offload = False + is_sequential_cpu_offload = False + recursive = False + for _, component in self.components.items(): + if isinstance(component, torch.nn.Module): + if hasattr(component, "_hf_hook"): + is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) + is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + recursive = is_sequential_cpu_offload + remove_hook_from_module(component, recurse=recursive) + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + ) + + # Offload back. + if is_model_cpu_offload: + self.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + self.enable_sequential_cpu_offload() + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + r""" + Save the LoRA parameters corresponding to the UNet and text encoder. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. Will be created if it doesn't exist. + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `unet`. + text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text + encoder LoRA state dict because it comes from 🤗 Transformers. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. + save_function (`Callable`): + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable + `DIFFUSERS_SAVE_MODE`. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. + """ + state_dict = {} + + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): + raise ValueError( + "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." + ) + + if unet_lora_layers: + state_dict.update(pack_weights(unet_lora_layers, "unet")) + + if text_encoder_lora_layers and text_encoder_2_lora_layers: + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + self.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 6595d8f456..cb4ec2f25b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -13,7 +13,6 @@ # limitations under the License. import inspect -import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -25,7 +24,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -36,8 +35,6 @@ from ...models.attention_processor import ( from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( - is_accelerate_available, - is_accelerate_version, is_invisible_watermark_available, logging, replace_example_docstring, @@ -128,7 +125,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin): +class StableDiffusionXLControlNetInpaintPipeline( + DiffusionPipeline, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin +): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -136,11 +135,11 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): @@ -308,7 +307,7 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -1510,108 +1509,3 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi return (image,) return StableDiffusionXLPipelineOutput(images=image) - - # Overrride to properly handle the loading and unloading of the additional text encoder. - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights - def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): - # We could have accessed the unet config from `lora_state_dict()` too. We pass - # it here explicitly to be able to tell that it's coming from an SDXL - # pipeline. - - # Remove any existing hooks. - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module - else: - raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") - - is_model_cpu_offload = False - is_sequential_cpu_offload = False - recursive = False - for _, component in self.components.items(): - if isinstance(component, torch.nn.Module): - if hasattr(component, "_hf_hook"): - is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - recursive = is_sequential_cpu_offload - remove_hook_from_module(component, recurse=recursive) - state_dict, network_alphas = self.lora_state_dict( - pretrained_model_name_or_path_or_dict, - unet_config=self.unet.config, - **kwargs, - ) - self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) - - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - ) - - text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} - if len(text_encoder_2_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_2_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder_2, - prefix="text_encoder_2", - lora_scale=self.lora_scale, - ) - - # Offload back. - if is_model_cpu_offload: - self.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - self.enable_sequential_cpu_offload() - - @classmethod - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights - def save_lora_weights( - self, - save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - state_dict = {} - - def pack_weights(layers, prefix): - layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict - - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." - ) - - if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet")) - - if text_encoder_lora_layers and text_encoder_2_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) - state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - - self.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch - def _remove_text_encoder_monkey_patch(self): - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index e2f463329c..cbb78e509b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -14,7 +14,6 @@ import inspect -import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -26,7 +25,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz from diffusers.utils.import_utils import is_invisible_watermark_available from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -37,8 +36,6 @@ from ...models.attention_processor import ( from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( - is_accelerate_available, - is_accelerate_version, logging, replace_example_docstring, ) @@ -103,7 +100,7 @@ EXAMPLE_DOC_STRING = """ class StableDiffusionXLControlNetPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. @@ -113,7 +110,7 @@ class StableDiffusionXLControlNetPipeline( The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - - [`loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files Args: @@ -283,7 +280,7 @@ class StableDiffusionXLControlNetPipeline( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -1176,108 +1173,3 @@ class StableDiffusionXLControlNetPipeline( return (image,) return StableDiffusionXLPipelineOutput(images=image) - - # Overrride to properly handle the loading and unloading of the additional text encoder. - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights - def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): - # We could have accessed the unet config from `lora_state_dict()` too. We pass - # it here explicitly to be able to tell that it's coming from an SDXL - # pipeline. - - # Remove any existing hooks. - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module - else: - raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") - - is_model_cpu_offload = False - is_sequential_cpu_offload = False - recursive = False - for _, component in self.components.items(): - if isinstance(component, torch.nn.Module): - if hasattr(component, "_hf_hook"): - is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - recursive = is_sequential_cpu_offload - remove_hook_from_module(component, recurse=recursive) - state_dict, network_alphas = self.lora_state_dict( - pretrained_model_name_or_path_or_dict, - unet_config=self.unet.config, - **kwargs, - ) - self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) - - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - ) - - text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} - if len(text_encoder_2_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_2_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder_2, - prefix="text_encoder_2", - lora_scale=self.lora_scale, - ) - - # Offload back. - if is_model_cpu_offload: - self.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - self.enable_sequential_cpu_offload() - - @classmethod - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights - def save_lora_weights( - self, - save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - state_dict = {} - - def pack_weights(layers, prefix): - layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict - - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." - ) - - if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet")) - - if text_encoder_lora_layers and text_encoder_2_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) - state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - - self.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch - def _remove_text_encoder_monkey_patch(self): - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 8337b70445..6fe3d0c641 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -25,7 +25,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz from diffusers.utils.import_utils import is_invisible_watermark_available from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -128,7 +128,9 @@ EXAMPLE_DOC_STRING = """ """ -class StableDiffusionXLControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class StableDiffusionXLControlNetImg2ImgPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin +): r""" Pipeline for image-to-image generation using Stable Diffusion XL with ControlNet guidance. @@ -137,7 +139,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(DiffusionPipeline, TextualInver In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] Args: vae ([`AutoencoderKL`]): @@ -316,7 +318,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(DiffusionPipeline, TextualInver # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 84fc9c7c57..40119a6087 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -13,7 +13,6 @@ # limitations under the License. import inspect -import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -22,7 +21,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz from ...image_processor import VaeImageProcessor from ...loaders import ( FromSingleFileMixin, - LoraLoaderMixin, + StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ) from ...models import AutoencoderKL, UNet2DConditionModel @@ -35,8 +34,6 @@ from ...models.attention_processor import ( from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( - is_accelerate_available, - is_accelerate_version, is_invisible_watermark_available, logging, replace_example_docstring, @@ -84,7 +81,9 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin): +class StableDiffusionXLPipeline( + DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin +): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -92,11 +91,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): @@ -257,7 +256,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -886,105 +885,3 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad return (image,) return StableDiffusionXLPipelineOutput(images=image) - - # Overrride to properly handle the loading and unloading of the additional text encoder. - def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): - # We could have accessed the unet config from `lora_state_dict()` too. We pass - # it here explicitly to be able to tell that it's coming from an SDXL - # pipeline. - - # Remove any existing hooks. - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module - else: - raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") - - is_model_cpu_offload = False - is_sequential_cpu_offload = False - recursive = False - for _, component in self.components.items(): - if isinstance(component, torch.nn.Module): - if hasattr(component, "_hf_hook"): - is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - recursive = is_sequential_cpu_offload - remove_hook_from_module(component, recurse=recursive) - state_dict, network_alphas = self.lora_state_dict( - pretrained_model_name_or_path_or_dict, - unet_config=self.unet.config, - **kwargs, - ) - self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) - - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - ) - - text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} - if len(text_encoder_2_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_2_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder_2, - prefix="text_encoder_2", - lora_scale=self.lora_scale, - ) - - # Offload back. - if is_model_cpu_offload: - self.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - self.enable_sequential_cpu_offload() - - @classmethod - def save_lora_weights( - self, - save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - state_dict = {} - - def pack_weights(layers, prefix): - layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict - - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." - ) - - if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet")) - - if text_encoder_lora_layers and text_encoder_2_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) - state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - - self.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - def _remove_text_encoder_monkey_patch(self): - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 4b66193f75..162fc828ff 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -13,7 +13,6 @@ # limitations under the License. import inspect -import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import PIL.Image @@ -21,7 +20,7 @@ import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -32,8 +31,6 @@ from ...models.attention_processor import ( from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( - is_accelerate_available, - is_accelerate_version, is_invisible_watermark_available, logging, replace_example_docstring, @@ -85,7 +82,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class StableDiffusionXLImg2ImgPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -94,11 +91,11 @@ class StableDiffusionXLImg2ImgPipeline( library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): @@ -266,7 +263,7 @@ class StableDiffusionXLImg2ImgPipeline( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -1036,108 +1033,3 @@ class StableDiffusionXLImg2ImgPipeline( return (image,) return StableDiffusionXLPipelineOutput(images=image) - - # Overrride to properly handle the loading and unloading of the additional text encoder. - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights - def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): - # We could have accessed the unet config from `lora_state_dict()` too. We pass - # it here explicitly to be able to tell that it's coming from an SDXL - # pipeline. - - # Remove any existing hooks. - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module - else: - raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") - - is_model_cpu_offload = False - is_sequential_cpu_offload = False - recursive = False - for _, component in self.components.items(): - if isinstance(component, torch.nn.Module): - if hasattr(component, "_hf_hook"): - is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - recursive = is_sequential_cpu_offload - remove_hook_from_module(component, recurse=recursive) - state_dict, network_alphas = self.lora_state_dict( - pretrained_model_name_or_path_or_dict, - unet_config=self.unet.config, - **kwargs, - ) - self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) - - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - ) - - text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} - if len(text_encoder_2_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_2_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder_2, - prefix="text_encoder_2", - lora_scale=self.lora_scale, - ) - - # Offload back. - if is_model_cpu_offload: - self.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - self.enable_sequential_cpu_offload() - - @classmethod - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights - def save_lora_weights( - self, - save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - state_dict = {} - - def pack_weights(layers, prefix): - layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict - - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." - ) - - if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet")) - - if text_encoder_lora_layers and text_encoder_2_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) - state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - - self.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch - def _remove_text_encoder_monkey_patch(self): - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 55baada042..25753859c2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -13,7 +13,6 @@ # limitations under the License. import inspect -import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -22,7 +21,7 @@ import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -34,8 +33,6 @@ from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( deprecate, - is_accelerate_available, - is_accelerate_version, is_invisible_watermark_available, logging, replace_example_docstring, @@ -231,7 +228,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool class StableDiffusionXLInpaintPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion XL. @@ -240,11 +237,11 @@ class StableDiffusionXLInpaintPipeline( library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): @@ -415,7 +412,7 @@ class StableDiffusionXLInpaintPipeline( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -1355,108 +1352,3 @@ class StableDiffusionXLInpaintPipeline( return (image,) return StableDiffusionXLPipelineOutput(images=image) - - # Overrride to properly handle the loading and unloading of the additional text encoder. - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights - def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): - # We could have accessed the unet config from `lora_state_dict()` too. We pass - # it here explicitly to be able to tell that it's coming from an SDXL - # pipeline. - - # Remove any existing hooks. - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module - else: - raise ImportError("Offloading requires `accelerate v0.17.0` or higher.") - - is_model_cpu_offload = False - is_sequential_cpu_offload = False - recursive = False - for _, component in self.components.items(): - if isinstance(component, torch.nn.Module): - if hasattr(component, "_hf_hook"): - is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload) - is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook) - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - recursive = is_sequential_cpu_offload - remove_hook_from_module(component, recurse=recursive) - state_dict, network_alphas = self.lora_state_dict( - pretrained_model_name_or_path_or_dict, - unet_config=self.unet.config, - **kwargs, - ) - self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) - - text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} - if len(text_encoder_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder, - prefix="text_encoder", - lora_scale=self.lora_scale, - ) - - text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} - if len(text_encoder_2_state_dict) > 0: - self.load_lora_into_text_encoder( - text_encoder_2_state_dict, - network_alphas=network_alphas, - text_encoder=self.text_encoder_2, - prefix="text_encoder_2", - lora_scale=self.lora_scale, - ) - - # Offload back. - if is_model_cpu_offload: - self.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - self.enable_sequential_cpu_offload() - - @classmethod - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights - def save_lora_weights( - self, - save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, - is_main_process: bool = True, - weight_name: str = None, - save_function: Callable = None, - safe_serialization: bool = True, - ): - state_dict = {} - - def pack_weights(layers, prefix): - layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers - layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} - return layers_state_dict - - if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): - raise ValueError( - "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." - ) - - if unet_lora_layers: - state_dict.update(pack_weights(unet_lora_layers, "unet")) - - if text_encoder_lora_layers and text_encoder_2_lora_layers: - state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) - state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) - - self.write_lora_layers( - state_dict=state_dict, - save_directory=save_directory, - is_main_process=is_main_process, - weight_name=weight_name, - save_function=save_function, - safe_serialization=safe_serialization, - ) - - # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch - def _remove_text_encoder_monkey_patch(self): - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) - self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 786231dd5c..0f951c78cb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -20,7 +20,7 @@ import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -93,7 +93,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class StableDiffusionXLInstructPix2PixPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin + DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin ): r""" Pipeline for pixel-level image editing by following text instructions. Based on Stable Diffusion XL. @@ -102,10 +102,10 @@ class StableDiffusionXLInstructPix2PixPipeline( library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] as well as the following saving methods: - - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] Args: vae ([`AutoencoderKL`]): @@ -268,7 +268,7 @@ class StableDiffusionXLInstructPix2PixPipeline( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale @@ -710,6 +710,14 @@ class StableDiffusionXLInstructPix2PixPipeline( For most cases, `target_size` should be set to the desired height and width of the generated image. If not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + aesthetic_score (`float`, *optional*, defaults to 6.0): + Used to simulate an aesthetic score of the generated image by influencing the positive text condition. + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + negative_aesthetic_score (`float`, *optional*, defaults to 2.5): + Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to + simulate an aesthetic score of the generated image by influencing the negative text condition. Examples: diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index d7441db707..6019d93fe0 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -23,7 +23,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel from ...models.attention_processor import ( AttnProcessor2_0, @@ -122,7 +122,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): class StableDiffusionXLAdapterPipeline( - DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin + DiffusionPipeline, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin ): r""" Pipeline for text-to-image generation using Stable Diffusion augmented with T2I-Adapter @@ -280,7 +280,7 @@ class StableDiffusionXLAdapterPipeline( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale