diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index cc8d095b53..392d5fb3fe 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -44,6 +44,8 @@ from typing_extensions import Self from .. import __version__ from ..configuration_utils import ConfigMixin +from ..models import AutoencoderKL +from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin from ..quantizers import PipelineQuantizationConfig from ..quantizers.bitsandbytes.utils import _check_bnb_status @@ -2171,13 +2173,136 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): class StableDiffusionMixin: - def __init__(self, *args, **kwargs): - deprecation_message = "`StableDiffusionMixin` from `diffusers.pipelines.pipeline_utils` is deprecated and this will be removed in a future version. Please use `StableDiffusionMixin` from `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils`, instead." - deprecate("StableDiffusionMixin", "1.0.0", deprecation_message) + r""" + Helper for DiffusionPipeline with vae and unet (mainly for LDMs such as Stable Diffusion). + """ - # To avoid circular imports and for being backwards-compatible. - from .stable_diffusion.pipeline_stable_diffusion_utils import ( - StableDiffusionMixin as ActualStableDiffusionMixin, + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." 
+ deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, ) + self.vae.enable_slicing() - ActualStableDiffusionMixin.__init__(self, *args, **kwargs) + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497. + + The suffixes after the scaling factors represent the stages where they are being applied. 
+ + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values + that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if not hasattr(self, "unet"): + raise ValueError("The pipeline must have `unet` for using FreeU.") + self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) + + def disable_freeu(self): + """Disables the FreeU mechanism if enabled.""" + self.unet.disable_freeu() + + def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + > [!WARNING] > This API is 🧪 experimental. + + Args: + unet (`bool`, defaults to `True`): To apply fusion on the UNet. + vae (`bool`, defaults to `True`): To apply fusion on the VAE. 
+ """ + self.fusing_unet = False + self.fusing_vae = False + + if unet: + self.fusing_unet = True + self.unet.fuse_qkv_projections() + self.unet.set_attn_processor(FusedAttnProcessor2_0()) + + if vae: + if not isinstance(self.vae, AutoencoderKL): + raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") + + self.fusing_vae = True + self.vae.fuse_qkv_projections() + self.vae.set_attn_processor(FusedAttnProcessor2_0()) + + def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): + """Disable QKV projection fusion if enabled. + + > [!WARNING] > This API is 🧪 experimental. + + Args: + unet (`bool`, defaults to `True`): To undo fusion on the UNet. + vae (`bool`, defaults to `True`): To undo fusion on the VAE. + + """ + if unet: + if not self.fusing_unet: + logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") + else: + self.unet.unfuse_qkv_projections() + self.fusing_unet = False + + if vae: + if not self.fusing_vae: + logger.warning("The VAE was not initially fused for QKV projections. 
Doing nothing.") + else: + self.vae.unfuse_qkv_projections() + self.fusing_vae = False diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index a04006d29b..976c51be55 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -34,7 +34,7 @@ from ...utils import ( ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents +from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents if is_torch_xla_available(): @@ -72,7 +72,7 @@ def preprocess(image): class StableDiffusionDepth2ImgPipeline( - DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin + DiffusionPipeline, SDMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin ): r""" Pipeline for text-guided depth-based image-to-image generation using Stable Diffusion. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index 88958d6d32..a4cd7de090 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -25,9 +25,9 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . 
import StableDiffusionPipelineOutput -from .pipeline_stable_diffusion_utils import StableDiffusionMixin +from .pipeline_stable_diffusion_utils import SDMixin from .safety_checker import StableDiffusionSafetyChecker @@ -41,7 +41,7 @@ else: logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin): +class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin, SDMixin): r""" Pipeline to generate image variations from an input image using Stable Diffusion. diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 3b6f8358f7..b6aa1c2522 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -34,9 +34,9 @@ from ...utils import ( replace_example_docstring, ) from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . 
import StableDiffusionPipelineOutput -from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents, retrieve_timesteps +from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents, retrieve_timesteps from .safety_checker import StableDiffusionSafetyChecker @@ -105,6 +105,7 @@ def preprocess(image): class StableDiffusionImg2ImgPipeline( DiffusionPipeline, StableDiffusionMixin, + SDMixin, TextualInversionLoaderMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index d30dd01c02..3595f837f9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -25,15 +25,11 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - deprecate, - is_torch_xla_available, - logging, -) +from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . 
import StableDiffusionPipelineOutput -from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents, retrieve_timesteps +from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents, retrieve_timesteps from .safety_checker import StableDiffusionSafetyChecker @@ -50,6 +46,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name class StableDiffusionInpaintPipeline( DiffusionPipeline, StableDiffusionMixin, + SDMixin, TextualInversionLoaderMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index d943c95123..f476730657 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -26,9 +26,9 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . 
import StableDiffusionPipelineOutput -from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents +from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents from .safety_checker import StableDiffusionSafetyChecker @@ -69,6 +69,7 @@ def preprocess(image): class StableDiffusionInstructPix2PixPipeline( DiffusionPipeline, StableDiffusionMixin, + SDMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 29782f46a5..66f27235a0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -27,8 +27,8 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin +from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents if is_torch_xla_available(): @@ -68,7 +68,7 @@ def preprocess(image): return image -class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin): +class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin, SDMixin, FromSingleFileMixin): r""" Pipeline for upscaling Stable Diffusion output image resolution by a factor of 2. 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 896f976623..a6c94745fb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -25,15 +25,11 @@ from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, Text from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers -from ...utils import ( - deprecate, - is_torch_xla_available, - logging, -) +from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import StableDiffusionPipelineOutput -from .pipeline_stable_diffusion_utils import StableDiffusionMixin +from .pipeline_stable_diffusion_utils import SDMixin if is_torch_xla_available(): @@ -75,6 +71,7 @@ def preprocess(image): class StableDiffusionUpscalePipeline( DiffusionPipeline, StableDiffusionMixin, + SDMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_utils.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_utils.py index c2afa9710d..7354effc0e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_utils.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_utils.py @@ -4,8 +4,6 @@ from typing import List, Optional, Union import torch from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL -from ...models.attention_processor import 
FusedAttnProcessor2_0 from ...models.lora import adjust_lora_scale_text_encoder from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers @@ -111,137 +109,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class StableDiffusionMixin: - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." - deprecate( - "enable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.enable_slicing() - - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." - deprecate( - "disable_vae_slicing", - "0.40.0", - depr_message, - ) - self.vae.disable_slicing() - - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." 
- deprecate( - "enable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.enable_tiling() - - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." - deprecate( - "disable_vae_tiling", - "0.40.0", - depr_message, - ) - self.vae.disable_tiling() - - def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): - r"""Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497. - - The suffixes after the scaling factors represent the stages where they are being applied. - - Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values - that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. - - Args: - s1 (`float`): - Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - s2 (`float`): - Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to - mitigate "oversmoothing effect" in the enhanced denoising process. - b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. - b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. - """ - if not hasattr(self, "unet"): - raise ValueError("The pipeline must have `unet` for using FreeU.") - self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2) - - def disable_freeu(self): - """Disables the FreeU mechanism if enabled.""" - self.unet.disable_freeu() - - def fuse_qkv_projections(self, unet: bool = True, vae: bool = True): - """ - Enables fused QKV projections. 
For self-attention modules, all projection matrices (i.e., query, key, value) - are fused. For cross-attention modules, key and value projection matrices are fused. - - > [!WARNING] > This API is 🧪 experimental. - - Args: - unet (`bool`, defaults to `True`): To apply fusion on the UNet. - vae (`bool`, defaults to `True`): To apply fusion on the VAE. - """ - self.fusing_unet = False - self.fusing_vae = False - - if unet: - self.fusing_unet = True - self.unet.fuse_qkv_projections() - self.unet.set_attn_processor(FusedAttnProcessor2_0()) - - if vae: - if not isinstance(self.vae, AutoencoderKL): - raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.") - - self.fusing_vae = True - self.vae.fuse_qkv_projections() - self.vae.set_attn_processor(FusedAttnProcessor2_0()) - - def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True): - """Disable QKV projection fusion if enabled. - - > [!WARNING] > This API is 🧪 experimental. - - Args: - unet (`bool`, defaults to `True`): To apply fusion on the UNet. - vae (`bool`, defaults to `True`): To apply fusion on the VAE. - - """ - if unet: - if not self.fusing_unet: - logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.") - else: - self.unet.unfuse_qkv_projections() - self.fusing_unet = False - - if vae: - if not self.fusing_vae: - logger.warning("The VAE was not initially fused for QKV projections. 
Doing nothing.") - else: - self.vae.unfuse_qkv_projections() - self.fusing_vae = False - +class SDMixin: def _encode_prompt( self, prompt, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 62e67a1c9b..6d03cba35f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -24,14 +24,10 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMix from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - is_torch_xla_available, - logging, - replace_example_docstring, -) +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .pipeline_stable_diffusion_utils import StableDiffusionMixin +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin +from .pipeline_stable_diffusion_utils import SDMixin from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer @@ -64,7 +60,7 @@ EXAMPLE_DOC_STRING = """ class StableUnCLIPPipeline( - DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin + DiffusionPipeline, StableDiffusionMixin, SDMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin ): """ Pipeline for text-to-image generation using stable unCLIP. 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index dd178af76a..f6b6a2b74f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -23,14 +23,10 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMix from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - is_torch_xla_available, - logging, - replace_example_docstring, -) +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor -from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput -from .pipeline_stable_diffusion_utils import StableDiffusionMixin +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin +from .pipeline_stable_diffusion_utils import SDMixin from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer @@ -74,7 +70,7 @@ EXAMPLE_DOC_STRING = """ class StableUnCLIPImg2ImgPipeline( - DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin + DiffusionPipeline, StableDiffusionMixin, SDMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin ): """ Pipeline for text-guided image-to-image generation using stable unCLIP.