From 71f24e36dee3e7407ebba250f052ca0008f19e50 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 13 Oct 2025 10:22:01 +0530 Subject: [PATCH] up --- .../models/autoencoders/autoencoder_dc.py | 25 +------------ .../models/autoencoders/autoencoder_kl.py | 33 +---------------- .../autoencoders/autoencoder_kl_allegro.py | 32 +--------------- .../autoencoders/autoencoder_kl_cogvideox.py | 25 +------------ .../autoencoders/autoencoder_kl_cosmos.py | 25 +------------ .../autoencoder_kl_hunyuan_video.py | 26 +------------ .../models/autoencoders/autoencoder_kl_ltx.py | 25 +------------ .../autoencoders/autoencoder_kl_magvit.py | 25 +------------ .../autoencoders/autoencoder_kl_mochi.py | 25 +------------ .../autoencoders/autoencoder_kl_qwenimage.py | 25 +------------ .../models/autoencoders/autoencoder_kl_wan.py | 25 +------------ .../autoencoders/autoencoder_oobleck.py | 17 +-------- .../models/autoencoders/autoencoder_tiny.py | 33 +---------------- .../autoencoders/consistency_decoder_vae.py | 37 +------------------ src/diffusers/models/autoencoders/vae.py | 35 ++++++++++++++++++ 15 files changed, 63 insertions(+), 350 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 783f22e97d..724ec3bb76 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -27,7 +27,7 @@ from ..attention_processor import SanaMultiscaleLinearAttention from ..modeling_utils import ModelMixin from ..normalization import RMSNorm, get_normalization from ..transformers.sana_transformer import GLUMBConv -from .vae import DecoderOutput, EncoderOutput +from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput class ResBlock(nn.Module): @@ -378,7 +378,7 @@ class Decoder(nn.Module): return hidden_states -class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in [SANA](https://huggingface.co/papers/2410.10629). @@ -536,27 +536,6 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin): self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio - def disable_tiling(self) -> None: - r""" - Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute - decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, height, width = x.shape diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py index d823c2fb8b..1a72aa3cfe 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl.py @@ -32,10 +32,10 @@ from ..attention_processor import ( ) from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder +from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder -class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): +class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. @@ -138,35 +138,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 - def enable_tiling(self, use_tiling: bool = True): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.use_tiling = use_tiling - - def disable_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.enable_tiling(False) - - def enable_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py index c24b8f42ac..6756586460 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -28,6 +28,7 @@ from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from ..resnet import ResnetBlock2D from ..upsampling import Upsample2D +from .vae import AutoencoderMixin class AllegroTemporalConvLayer(nn.Module): @@ -673,7 +674,7 @@ class AllegroDecoder3D(nn.Module): return sample -class AutoencoderKLAllegro(ModelMixin, ConfigMixin): +class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in [Allegro](https://github.com/rhymes-ai/Allegro). @@ -795,35 +796,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin): sample_size - self.tile_overlap_w, ) - def enable_tiling(self) -> None: - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.use_tiling = True - - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: torch.Tensor) -> torch.Tensor: # TODO(aryan) # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py index e0e9436e89..5096b725d0 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py @@ -29,7 +29,7 @@ from ..downsampling import CogVideoXDownsample3D from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from ..upsampling import CogVideoXUpsample3D -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -955,7 +955,7 @@ class CogVideoXDecoder3D(nn.Module): return hidden_states, new_conv_cache -class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in [CogVideoX](https://github.com/THUDM/CogVideo). @@ -1124,27 +1124,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin): self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = x.shape diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py index 500e316ebc..b17522d1c4 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py @@ -24,7 +24,7 @@ from ...utils import get_logger from ...utils.accelerate_utils import apply_forward_hook from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, IdentityDistribution +from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution logger = get_logger(__name__) @@ -875,7 +875,7 @@ class CosmosDecoder3d(nn.Module): return hidden_states -class AutoencoderKLCosmos(ModelMixin, ConfigMixin): +class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin): r""" Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575). @@ -1031,27 +1031,6 @@ class AutoencoderKLCosmos(ModelMixin, ConfigMixin): self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: torch.Tensor) -> torch.Tensor: x = self.encoder(x) enc = self.quant_conv(x) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py index 7b0f9889a5..88b9bb507f 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py @@ -18,7 +18,6 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import torch.utils.checkpoint from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging @@ -27,7 +26,7 @@ from ..activations import get_activation from ..attention_processor import Attention from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -625,7 +624,7 @@ class HunyuanVideoDecoder3D(nn.Module): return hidden_states -class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): +class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603). @@ -764,27 +763,6 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin): self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = x.shape diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py index 51c600a4e9..47f2081b7e 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py @@ -26,7 +26,7 @@ from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from ..normalization import RMSNorm -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution class LTXVideoCausalConv3d(nn.Module): @@ -1034,7 +1034,7 @@ class LTXVideoDecoder3d(nn.Module): return hidden_states -class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in [LTX](https://huggingface.co/Lightricks/LTX-Video). @@ -1219,27 +1219,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin): self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _encode(self, x: torch.Tensor) -> torch.Tensor: batch_size, num_channels, num_frames, height, width = x.shape diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py b/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py index 43294a901f..97ca9d6692 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py @@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -663,7 +663,7 @@ class EasyAnimateDecoder(nn.Module): return hidden_states -class AutoencoderKLMagvit(ModelMixin, ConfigMixin): +class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991). @@ -805,27 +805,6 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin): self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - @apply_forward_hook def _encode( self, x: torch.Tensor, return_dict: bool = True diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py index 404d2f6d86..3ded9a0a54 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py @@ -27,7 +27,7 @@ from ..attention_processor import Attention, MochiVaeAttnProcessor2_0 from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -657,7 +657,7 @@ class MochiDecoder3D(nn.Module): return hidden_states, new_conv_cache -class AutoencoderKLMochi(ModelMixin, ConfigMixin): +class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in [Mochi 1 preview](https://github.com/genmoai/models). @@ -818,27 +818,6 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin): self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def _enable_framewise_encoding(self): r""" Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py index 87ac406592..844530d1f1 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py @@ -32,7 +32,7 @@ from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -664,7 +664,7 @@ class QwenImageDecoder3d(nn.Module): return x -class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. @@ -764,27 +764,6 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin): self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def clear_cache(self): def _count_conv3d(model): count = 0 diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index e6e58c1cce..cc3fd664da 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook from ..activations import get_activation from ..modeling_outputs import AutoencoderKLOutput from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DiagonalGaussianDistribution +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -952,7 +952,7 @@ def unpatchify(x, patch_size): return x -class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): +class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin): r""" A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Introduced in [Wan 2.1]. @@ -1111,27 +1111,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin): self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_tiling = False - - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - def clear_cache(self): # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call self._conv_num = self._cached_conv_counts["decoder"] diff --git a/src/diffusers/models/autoencoders/autoencoder_oobleck.py b/src/diffusers/models/autoencoders/autoencoder_oobleck.py index a10b616b4e..d832645592 100644 --- a/src/diffusers/models/autoencoders/autoencoder_oobleck.py +++ b/src/diffusers/models/autoencoders/autoencoder_oobleck.py @@ -25,6 +25,7 @@ from ...utils import BaseOutput from ...utils.accelerate_utils import apply_forward_hook from ...utils.torch_utils import randn_tensor from ..modeling_utils import ModelMixin +from .vae import AutoencoderMixin class Snake1d(nn.Module): @@ -291,7 +292,7 @@ class OobleckDecoder(nn.Module): return hidden_state -class AutoencoderOobleck(ModelMixin, ConfigMixin): +class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin): r""" An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First introduced in Stable Audio. @@ -356,20 +357,6 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin): self.use_slicing = False - def enable_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - @apply_forward_hook def encode( self, x: torch.Tensor, return_dict: bool = True diff --git a/src/diffusers/models/autoencoders/autoencoder_tiny.py b/src/diffusers/models/autoencoders/autoencoder_tiny.py index 3e2b28606e..b9ac713d73 100644 --- a/src/diffusers/models/autoencoders/autoencoder_tiny.py +++ b/src/diffusers/models/autoencoders/autoencoder_tiny.py @@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput from ...utils.accelerate_utils import apply_forward_hook from ..modeling_utils import ModelMixin -from .vae import DecoderOutput, DecoderTiny, EncoderTiny +from .vae import AutoencoderMixin, DecoderOutput, DecoderTiny, EncoderTiny @dataclass @@ -38,7 +38,7 @@ class AutoencoderTinyOutput(BaseOutput): latents: torch.Tensor -class AutoencoderTiny(ModelMixin, ConfigMixin): +class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin): r""" A tiny distilled VAE model for encoding images into latents and decoding latent representations into images. @@ -162,35 +162,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): """[0, 1] -> raw latents""" return x.sub(self.latent_shift).mul(2 * self.latent_magnitude) - def enable_slicing(self) -> None: - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - def disable_slicing(self) -> None: - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - - def enable_tiling(self, use_tiling: bool = True) -> None: - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.use_tiling = use_tiling - - def disable_tiling(self) -> None: - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.enable_tiling(False) - def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor: r"""Encode a batch of images using a tiled encoder. diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py index b3017a8780..0a6258fed3 100644 --- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py +++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py @@ -32,7 +32,7 @@ from ..attention_processor import ( ) from ..modeling_utils import ModelMixin from ..unets.unet_2d import UNet2DModel -from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder +from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder @dataclass @@ -49,7 +49,7 @@ class ConsistencyDecoderVAEOutput(BaseOutput): latent_dist: "DiagonalGaussianDistribution" -class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): +class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin): r""" The consistency decoder used with DALL-E 3. @@ -167,39 +167,6 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) self.tile_overlap_factor = 0.25 - # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling - def enable_tiling(self, use_tiling: bool = True): - r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. - """ - self.use_tiling = use_tiling - - # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling - def disable_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.enable_tiling(False) - - # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing - def enable_slicing(self): - r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. - """ - self.use_slicing = True - - # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing - def disable_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing - decoding in one step. - """ - self.use_slicing = False - @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 1d74d4f472..c8f29aeadf 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -894,3 +894,38 @@ class DecoderTiny(nn.Module): # scale image from [0, 1] to [-1, 1] to match diffusers convention return x.mul(2).sub(1) + + +class AutoencoderMixin: + def enable_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + if not hasattr(self, "use_tiling"): + raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.") + self.use_tiling = True + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + if not hasattr(self, "use_slicing"): + raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.") + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False