diff --git a/docs/source/en/api/models/vq.md b/docs/source/en/api/models/vq.md
index a5ac6ba63e..fa0631e6fe 100644
--- a/docs/source/en/api/models/vq.md
+++ b/docs/source/en/api/models/vq.md
@@ -24,4 +24,4 @@ The abstract from the paper is:
 
 ## VQEncoderOutput
 
-[[autodoc]] models.vq_model.VQEncoderOutput
+[[autodoc]] models.autoencoders.vq_model.VQEncoderOutput
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 78b0efff92..6b29dd5f54 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -31,6 +31,7 @@ if is_torch_available():
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
+    _import_structure["autoencoders.vq_model"] = ["VQModel"]
     _import_structure["controlnet"] = ["ControlNetModel"]
     _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
     _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
@@ -50,7 +51,6 @@ if is_torch_available():
     _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
     _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
     _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
-    _import_structure["vq_model"] = ["VQModel"]
 
 if is_flax_available():
     _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
@@ -67,6 +67,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         AutoencoderKLTemporalDecoder,
         AutoencoderTiny,
         ConsistencyDecoderVAE,
+        VQModel,
     )
     from .controlnet import ControlNetModel
     from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
@@ -92,7 +93,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         UNetSpatioTemporalConditionModel,
         UVit2DModel,
     )
-    from .vq_model import VQModel
 
 if is_flax_available():
     from .controlnet_flax import FlaxControlNetModel
diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py
index 201a40ff17..5c47748d62 100644
--- a/src/diffusers/models/autoencoders/__init__.py
+++ b/src/diffusers/models/autoencoders/__init__.py
@@ -3,3 +3,4 @@ from .autoencoder_kl import AutoencoderKL
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_tiny import AutoencoderTiny
 from .consistency_decoder_vae import ConsistencyDecoderVAE
+from .vq_model import VQModel
diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py
new file mode 100644
index 0000000000..2f9e75623e
--- /dev/null
+++ b/src/diffusers/models/autoencoders/vq_model.py
@@ -0,0 +1,182 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput
+from ...utils.accelerate_utils import apply_forward_hook
+from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
+from ..modeling_utils import ModelMixin
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+    """
+    Output of VQModel encoding method.
+
+    Args:
+        latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            The encoded output sample from the last layer of the model.
+    """
+
+    latents: torch.Tensor
+
+
+class VQModel(ModelMixin, ConfigMixin):
+    r"""
+    A VQ-VAE model for decoding latent representations.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
+        layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
+        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+        scaling_factor (`float`, *optional*, defaults to `0.18215`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        norm_type (`str`, *optional*, defaults to `"group"`):
+            Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int, ...] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 3,
+        sample_size: int = 32,
+        num_vq_embeddings: int = 256,
+        norm_num_groups: int = 32,
+        vq_embed_dim: Optional[int] = None,
+        scaling_factor: float = 0.18215,
+        norm_type: str = "group",  # group, spatial
+        mid_block_add_attention=True,
+        lookup_from_codebook=False,
+        force_upcast=False,
+    ):
+        super().__init__()
+
+        # pass init params to Encoder
+        self.encoder = Encoder(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            double_z=False,
+            mid_block_add_attention=mid_block_add_attention,
+        )
+
+        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
+
+        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
+        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
+        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
+
+        # pass init params to Decoder
+        self.decoder = Decoder(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            norm_type=norm_type,
+            mid_block_add_attention=mid_block_add_attention,
+        )
+
+    @apply_forward_hook
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+
+        if not return_dict:
+            return (h,)
+
+        return VQEncoderOutput(latents=h)
+
+    @apply_forward_hook
+    def decode(
+        self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        # also go through quantization layer
+        if not force_not_quantize:
+            quant, commit_loss, _ = self.quantize(h)
+        elif self.config.lookup_from_codebook:
+            quant = self.quantize.get_codebook_entry(h, shape)
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
+        else:
+            quant = h
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
+        quant2 = self.post_quant_conv(quant)
+        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
+
+        if not return_dict:
+            return dec, commit_loss
+
+        return DecoderOutput(sample=dec, commit_loss=commit_loss)
+
+    def forward(
+        self, sample: torch.Tensor, return_dict: bool = True
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
+        r"""
+        The [`VQModel`] forward method.
+
+        Args:
+            sample (`torch.Tensor`): Input sample.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vq_model.VQEncoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
+                is returned.
+        """
+
+        h = self.encode(sample).latents
+        dec = self.decode(h)
+
+        if not return_dict:
+            return dec.sample, dec.commit_loss
+        return dec
diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py
index cb32b1f407..71aeb09049 100644
--- a/src/diffusers/models/vq_model.py
+++ b/src/diffusers/models/vq_model.py
@@ -11,172 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from dataclasses import dataclass
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput
-from ..utils.accelerate_utils import apply_forward_hook
-from .autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
-from .modeling_utils import ModelMixin
+from ..utils import deprecate
+from .autoencoders.vq_model import VQEncoderOutput, VQModel
 
 
-@dataclass
-class VQEncoderOutput(BaseOutput):
-    """
-    Output of VQModel encoding method.
-
-    Args:
-        latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
-            The encoded output sample from the last layer of the model.
-    """
-
-    latents: torch.Tensor
+class VQEncoderOutput(VQEncoderOutput):
+    deprecation_message = "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead."
+    deprecate("VQEncoderOutput", "0.31", deprecation_message)
 
 
-class VQModel(ModelMixin, ConfigMixin):
-    r"""
-    A VQ-VAE model for decoding latent representations.
-
-    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
-    for all models (such as downloading or saving).
-
-    Parameters:
-        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
-        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
-        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
-            Tuple of downsample block types.
-        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
-            Tuple of upsample block types.
-        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
-            Tuple of block output channels.
-        layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
-        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
-        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
-        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
-        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
-        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
-        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
-        scaling_factor (`float`, *optional*, defaults to `0.18215`):
-            The component-wise standard deviation of the trained latent space computed using the first batch of the
-            training set. This is used to scale the latent space to have unit variance when training the diffusion
-            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
-            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
-            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
-        norm_type (`str`, *optional*, defaults to `"group"`):
-            Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
-    """
-
-    @register_to_config
-    def __init__(
-        self,
-        in_channels: int = 3,
-        out_channels: int = 3,
-        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
-        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
-        block_out_channels: Tuple[int, ...] = (64,),
-        layers_per_block: int = 1,
-        act_fn: str = "silu",
-        latent_channels: int = 3,
-        sample_size: int = 32,
-        num_vq_embeddings: int = 256,
-        norm_num_groups: int = 32,
-        vq_embed_dim: Optional[int] = None,
-        scaling_factor: float = 0.18215,
-        norm_type: str = "group",  # group, spatial
-        mid_block_add_attention=True,
-        lookup_from_codebook=False,
-        force_upcast=False,
-    ):
-        super().__init__()
-
-        # pass init params to Encoder
-        self.encoder = Encoder(
-            in_channels=in_channels,
-            out_channels=latent_channels,
-            down_block_types=down_block_types,
-            block_out_channels=block_out_channels,
-            layers_per_block=layers_per_block,
-            act_fn=act_fn,
-            norm_num_groups=norm_num_groups,
-            double_z=False,
-            mid_block_add_attention=mid_block_add_attention,
-        )
-
-        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
-
-        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
-        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
-        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
-
-        # pass init params to Decoder
-        self.decoder = Decoder(
-            in_channels=latent_channels,
-            out_channels=out_channels,
-            up_block_types=up_block_types,
-            block_out_channels=block_out_channels,
-            layers_per_block=layers_per_block,
-            act_fn=act_fn,
-            norm_num_groups=norm_num_groups,
-            norm_type=norm_type,
-            mid_block_add_attention=mid_block_add_attention,
-        )
-
-    @apply_forward_hook
-    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
-        h = self.encoder(x)
-        h = self.quant_conv(h)
-
-        if not return_dict:
-            return (h,)
-
-        return VQEncoderOutput(latents=h)
-
-    @apply_forward_hook
-    def decode(
-        self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
-    ) -> Union[DecoderOutput, torch.Tensor]:
-        # also go through quantization layer
-        if not force_not_quantize:
-            quant, commit_loss, _ = self.quantize(h)
-        elif self.config.lookup_from_codebook:
-            quant = self.quantize.get_codebook_entry(h, shape)
-            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
-        else:
-            quant = h
-            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
-        quant2 = self.post_quant_conv(quant)
-        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
-
-        if not return_dict:
-            return dec, commit_loss
-
-        return DecoderOutput(sample=dec, commit_loss=commit_loss)
-
-    def forward(
-        self, sample: torch.Tensor, return_dict: bool = True
-    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
-        r"""
-        The [`VQModel`] forward method.
-
-        Args:
-            sample (`torch.Tensor`): Input sample.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
-
-        Returns:
-            [`~models.vq_model.VQEncoderOutput`] or `tuple`:
-                If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
-                is returned.
-        """
-
-        h = self.encode(sample).latents
-        dec = self.decode(h)
-
-        if not return_dict:
-            return dec.sample, dec.commit_loss
-        return dec
+class VQModel(VQModel):
+    deprecation_message = "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead."
+    deprecate("VQModel", "0.31", deprecation_message)
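
Usage sketch (review note, not part of the patch): the snippet below illustrates the new canonical import path and the encode/decode round trip of the relocated VQModel, using only the constructor defaults shown in the new autoencoders/vq_model.py; the deprecation window is the one stated in the shim above.

    import torch

    # New canonical path introduced by this change.
    from diffusers.models.autoencoders.vq_model import VQModel
    # `from diffusers.models.vq_model import VQModel` still resolves through the shim,
    # but emits a deprecation warning and is scheduled for removal in 0.31.

    vq = VQModel()  # defaults: 3 latent channels, one down/up block, 256 codebook entries

    sample = torch.randn(1, 3, 32, 32)
    latents = vq.encode(sample).latents                 # VQEncoderOutput.latents
    out = vq.decode(latents)                            # DecoderOutput(sample=..., commit_loss=...)
    recon, commit_loss = vq(sample, return_dict=False)  # forward() returns (sample, commit_loss)
    print(recon.shape, commit_loss.item())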