sayakpaul
2025-11-27 13:25:30 +05:30
parent 15e3a0fe60
commit 8094f660cf
11 changed files with 163 additions and 182 deletions

src/diffusers/pipelines/pipeline_utils.py

@@ -44,6 +44,8 @@ from typing_extensions import Self
 from .. import __version__
 from ..configuration_utils import ConfigMixin
+from ..models import AutoencoderKL
+from ..models.attention_processor import FusedAttnProcessor2_0
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
 from ..quantizers import PipelineQuantizationConfig
 from ..quantizers.bitsandbytes.utils import _check_bnb_status
@@ -2171,13 +2173,136 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
 class StableDiffusionMixin:
-    def __init__(self, *args, **kwargs):
-        deprecation_message = "`StableDiffusionMixin` from `diffusers.pipelines.pipeline_utils` is deprecated and this will be removed in a future version. Please use `StableDiffusionMixin` from `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils`, instead."
-        deprecate("StableDiffusionMixin", "1.0.0", deprecation_message)
-        # To avoid circular imports and for being backwards-compatible.
-        from .stable_diffusion.pipeline_stable_diffusion_utils import (
-            StableDiffusionMixin as ActualStableDiffusionMixin,
-        )
-        ActualStableDiffusionMixin.__init__(self, *args, **kwargs)
+    r"""
+    Helper for DiffusionPipeline with vae and unet (mainly for LDMs such as Stable Diffusion).
+    """
+
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+        deprecate(
+            "enable_vae_slicing",
+            "0.40.0",
+            depr_message,
+        )
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+        deprecate(
+            "disable_vae_slicing",
+            "0.40.0",
+            depr_message,
+        )
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+        deprecate(
+            "enable_vae_tiling",
+            "0.40.0",
+            depr_message,
+        )
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+        deprecate(
+            "disable_vae_tiling",
+            "0.40.0",
+            depr_message,
+        )
+        self.vae.disable_tiling()
+
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+        r"""Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.
+
+        The suffixes after the scaling factors represent the stages where they are being applied.
+
+        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+        Args:
+            s1 (`float`):
+                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            s2 (`float`):
+                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+        """
+        if not hasattr(self, "unet"):
+            raise ValueError("The pipeline must have `unet` for using FreeU.")
+        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+    def disable_freeu(self):
+        """Disables the FreeU mechanism if enabled."""
+        self.unet.disable_freeu()
+
+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        > [!WARNING]
+        > This API is 🧪 experimental.
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        > [!WARNING]
+        > This API is 🧪 experimental.
+
+        Args:
+            unet (`bool`, defaults to `True`): To unfuse the QKV projections on the UNet.
+            vae (`bool`, defaults to `True`): To unfuse the QKV projections on the VAE.
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
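
Note: each helper restored above is a thin shim that warns via `deprecate` and forwards to the underlying model, so the migration is mechanical. A minimal usage sketch under that assumption — the checkpoint id is illustrative, and the FreeU values are the SD v1 recommendations from the FreeU repository:

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any pipeline exposing `vae` and `unet` behaves the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Deprecated after this commit (warns, slated for removal in 0.40.0):
#   pipe.enable_vae_slicing(); pipe.enable_vae_tiling()
# Preferred: call the VAE directly.
pipe.vae.enable_slicing()  # decode the batch one slice at a time
pipe.vae.enable_tiling()   # decode/encode large inputs tile by tile

# FreeU keeps its pipeline-level entry point and forwards to the UNet.
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
image = pipe("an astronaut riding a horse").images[0]
pipe.disable_freeu()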

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

@@ -34,7 +34,7 @@ from ...utils import (
 )
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents
+from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents

 if is_torch_xla_available():
@@ -72,7 +72,7 @@ def preprocess(image):
 class StableDiffusionDepth2ImgPipeline(
-    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+    DiffusionPipeline, SDMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
 ):
     r"""
     Pipeline for text-guided depth-based image-to-image generation using Stable Diffusion.

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py

@@ -25,9 +25,9 @@ from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
-from .pipeline_stable_diffusion_utils import StableDiffusionMixin
+from .pipeline_stable_diffusion_utils import SDMixin
 from .safety_checker import StableDiffusionSafetyChecker
@@ -41,7 +41,7 @@ else:
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

-class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin):
+class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin, SDMixin):
     r"""
     Pipeline to generate image variations from an input image using Stable Diffusion.

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -34,9 +34,9 @@ from ...utils import (
     replace_example_docstring,
 )
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
-from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents, retrieve_timesteps
+from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents, retrieve_timesteps
 from .safety_checker import StableDiffusionSafetyChecker
@@ -105,6 +105,7 @@ def preprocess(image):
 class StableDiffusionImg2ImgPipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
+    SDMixin,
     TextualInversionLoaderMixin,
     IPAdapterMixin,
     StableDiffusionLoraLoaderMixin,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

@@ -25,15 +25,11 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import (
-    deprecate,
-    is_torch_xla_available,
-    logging,
-)
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
-from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents, retrieve_timesteps
+from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents, retrieve_timesteps
 from .safety_checker import StableDiffusionSafetyChecker
@@ -50,6 +46,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 class StableDiffusionInpaintPipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
+    SDMixin,
     TextualInversionLoaderMixin,
     IPAdapterMixin,
     StableDiffusionLoraLoaderMixin,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

@@ -26,9 +26,9 @@ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
-from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents
+from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents
 from .safety_checker import StableDiffusionSafetyChecker
@@ -69,6 +69,7 @@ def preprocess(image):
 class StableDiffusionInstructPix2PixPipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
+    SDMixin,
     TextualInversionLoaderMixin,
     StableDiffusionLoraLoaderMixin,
     IPAdapterMixin,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

@@ -27,8 +27,8 @@ from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import EulerDiscreteScheduler
 from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
+from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents

 if is_torch_xla_available():
@@ -68,7 +68,7 @@ def preprocess(image):
     return image

-class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin):
+class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin, SDMixin, FromSingleFileMixin):
     r"""
     Pipeline for upscaling Stable Diffusion output image resolution by a factor of 2.

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

@@ -25,15 +25,11 @@ from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, Text
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import (
-    deprecate,
-    is_torch_xla_available,
-    logging,
-)
+from ...utils import deprecate, is_torch_xla_available, logging
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from . import StableDiffusionPipelineOutput
-from .pipeline_stable_diffusion_utils import StableDiffusionMixin
+from .pipeline_stable_diffusion_utils import SDMixin

 if is_torch_xla_available():
@@ -75,6 +71,7 @@ def preprocess(image):
 class StableDiffusionUpscalePipeline(
     DiffusionPipeline,
     StableDiffusionMixin,
+    SDMixin,
     TextualInversionLoaderMixin,
     StableDiffusionLoraLoaderMixin,
     FromSingleFileMixin,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_utils.py

@@ -4,8 +4,6 @@ from typing import List, Optional, Union
 import torch

 from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL
-from ...models.attention_processor import FusedAttnProcessor2_0
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -111,137 +109,7 @@ def retrieve_latents(
     raise AttributeError("Could not access latents of provided encoder_output")

-class StableDiffusionMixin:
-    def enable_vae_slicing(self):
-        r"""
-        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
-        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
-        """
-        depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
-        deprecate(
-            "enable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.enable_slicing()
-
-    def disable_vae_slicing(self):
-        r"""
-        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
-        deprecate(
-            "disable_vae_slicing",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.disable_slicing()
-
-    def enable_vae_tiling(self):
-        r"""
-        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
-        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
-        processing larger images.
-        """
-        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
-        deprecate(
-            "enable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.enable_tiling()
-
-    def disable_vae_tiling(self):
-        r"""
-        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
-        computing decoding in one step.
-        """
-        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
-        deprecate(
-            "disable_vae_tiling",
-            "0.40.0",
-            depr_message,
-        )
-        self.vae.disable_tiling()
-
-    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
-        r"""Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.
-
-        The suffixes after the scaling factors represent the stages where they are being applied.
-
-        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
-        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
-
-        Args:
-            s1 (`float`):
-                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
-                mitigate "oversmoothing effect" in the enhanced denoising process.
-            s2 (`float`):
-                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
-                mitigate "oversmoothing effect" in the enhanced denoising process.
-            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
-            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
-        """
-        if not hasattr(self, "unet"):
-            raise ValueError("The pipeline must have `unet` for using FreeU.")
-        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
-
-    def disable_freeu(self):
-        """Disables the FreeU mechanism if enabled."""
-        self.unet.disable_freeu()
-
-    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        > [!WARNING]
-        > This API is 🧪 experimental.
-
-        Args:
-            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
-            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
-        """
-        self.fusing_unet = False
-        self.fusing_vae = False
-
-        if unet:
-            self.fusing_unet = True
-            self.unet.fuse_qkv_projections()
-            self.unet.set_attn_processor(FusedAttnProcessor2_0())
-
-        if vae:
-            if not isinstance(self.vae, AutoencoderKL):
-                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
-            self.fusing_vae = True
-            self.vae.fuse_qkv_projections()
-            self.vae.set_attn_processor(FusedAttnProcessor2_0())
-
-    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
-        """Disable QKV projection fusion if enabled.
-
-        > [!WARNING]
-        > This API is 🧪 experimental.
-
-        Args:
-            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
-            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
-        """
-        if unet:
-            if not self.fusing_unet:
-                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
-            else:
-                self.unet.unfuse_qkv_projections()
-                self.fusing_unet = False
-
-        if vae:
-            if not self.fusing_vae:
-                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
-            else:
-                self.vae.unfuse_qkv_projections()
-                self.fusing_vae = False
-
+class SDMixin:
     def _encode_prompt(
         self,
         prompt,
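
Note: together with the pipeline_utils.py hunk above, the split is that the deprecated vae/FreeU/QKV helpers remain importable as `StableDiffusionMixin` from `pipeline_utils`, while the shared prompt- and latent-handling utilities live on under the new `SDMixin` name. A rough sketch of the composition pattern the per-pipeline hunks apply — `MyPipeline` is a hypothetical stand-in, and the module paths are the ones this diff itself uses:

from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils import SDMixin

class MyPipeline(DiffusionPipeline, StableDiffusionMixin, SDMixin):
    # StableDiffusionMixin contributes enable_vae_slicing(), enable_freeu(),
    # fuse_qkv_projections(), ... (each now emitting a deprecation warning);
    # SDMixin contributes _encode_prompt() and the related shared helpers.
    ...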

src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py

@@ -24,14 +24,10 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMix
 from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import (
-    is_torch_xla_available,
-    logging,
-    replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .pipeline_stable_diffusion_utils import StableDiffusionMixin
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
+from .pipeline_stable_diffusion_utils import SDMixin
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
@@ -64,7 +60,7 @@ EXAMPLE_DOC_STRING = """
 class StableUnCLIPPipeline(
-    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+    DiffusionPipeline, StableDiffusionMixin, SDMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
 ):
     """
     Pipeline for text-to-image generation using stable unCLIP.

src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py

@@ -23,14 +23,10 @@ from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMix
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.embeddings import get_timestep_embedding
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import (
-    is_torch_xla_available,
-    logging,
-    replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .pipeline_stable_diffusion_utils import StableDiffusionMixin
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
+from .pipeline_stable_diffusion_utils import SDMixin
 from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
@@ -74,7 +70,7 @@ EXAMPLE_DOC_STRING = """
 class StableUnCLIPImg2ImgPipeline(
-    DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+    DiffusionPipeline, StableDiffusionMixin, SDMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
 ):
     """
     Pipeline for text-guided image-to-image generation using stable unCLIP.