mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
move vqmodel to models.autoencoders. (#8292)
move vqmodel to models.autoencoders.
This commit is contained in:
@@ -24,4 +24,4 @@ The abstract from the paper is:
|
||||
|
||||
## VQEncoderOutput
|
||||
|
||||
[[autodoc]] models.vq_model.VQEncoderOutput
|
||||
[[autodoc]] models.autoencoders.vq_model.VQEncoderOutput
|
||||
|
||||
@@ -31,6 +31,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
@@ -50,7 +51,6 @@ if is_torch_available():
|
||||
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
|
||||
_import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
|
||||
_import_structure["unets.uvit_2d"] = ["UVit2DModel"]
|
||||
_import_structure["vq_model"] = ["VQModel"]
|
||||
|
||||
if is_flax_available():
|
||||
_import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
|
||||
@@ -67,6 +67,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
VQModel,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
|
||||
@@ -92,7 +93,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
)
|
||||
from .vq_model import VQModel
|
||||
|
||||
if is_flax_available():
|
||||
from .controlnet_flax import FlaxControlNetModel
|
||||
|
||||
@@ -3,3 +3,4 @@ from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
from .vq_model import VQModel
|
||||
|
||||
182
src/diffusers/models/autoencoders/vq_model.py
Normal file
182
src/diffusers/models/autoencoders/vq_model.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class VQEncoderOutput(BaseOutput):
|
||||
"""
|
||||
Output of VQModel encoding method.
|
||||
|
||||
Args:
|
||||
latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The encoded output sample from the last layer of the model.
|
||||
"""
|
||||
|
||||
latents: torch.Tensor
|
||||
|
||||
|
||||
class VQModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VQ-VAE model for decoding latent representations.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
|
||||
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
|
||||
scaling_factor (`float`, *optional*, defaults to `0.18215`):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
||||
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
||||
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
norm_type (`str`, *optional*, defaults to `"group"`):
|
||||
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 3,
|
||||
sample_size: int = 32,
|
||||
num_vq_embeddings: int = 256,
|
||||
norm_num_groups: int = 32,
|
||||
vq_embed_dim: Optional[int] = None,
|
||||
scaling_factor: float = 0.18215,
|
||||
norm_type: str = "group", # group, spatial
|
||||
mid_block_add_attention=True,
|
||||
lookup_from_codebook=False,
|
||||
force_upcast=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
double_z=False,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
|
||||
|
||||
self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
|
||||
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
|
||||
self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = Decoder(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_type=norm_type,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
|
||||
if not return_dict:
|
||||
return (h,)
|
||||
|
||||
return VQEncoderOutput(latents=h)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, commit_loss, _ = self.quantize(h)
|
||||
elif self.config.lookup_from_codebook:
|
||||
quant = self.quantize.get_codebook_entry(h, shape)
|
||||
commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
|
||||
else:
|
||||
quant = h
|
||||
commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
|
||||
quant2 = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
|
||||
|
||||
if not return_dict:
|
||||
return dec, commit_loss
|
||||
|
||||
return DecoderOutput(sample=dec, commit_loss=commit_loss)
|
||||
|
||||
def forward(
|
||||
self, sample: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
|
||||
r"""
|
||||
The [`VQModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vq_model.VQEncoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
|
||||
is returned.
|
||||
"""
|
||||
|
||||
h = self.encode(sample).latents
|
||||
dec = self.decode(h)
|
||||
|
||||
if not return_dict:
|
||||
return dec.sample, dec.commit_loss
|
||||
return dec
|
||||
@@ -11,172 +11,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.accelerate_utils import apply_forward_hook
|
||||
from .autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
|
||||
from .modeling_utils import ModelMixin
|
||||
from ..utils import deprecate
|
||||
from .autoencoders.vq_model import VQEncoderOutput, VQModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class VQEncoderOutput(BaseOutput):
|
||||
"""
|
||||
Output of VQModel encoding method.
|
||||
|
||||
Args:
|
||||
latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The encoded output sample from the last layer of the model.
|
||||
"""
|
||||
|
||||
latents: torch.Tensor
|
||||
class VQEncoderOutput(VQEncoderOutput):
|
||||
deprecation_message = "Importing `VQEncoderOutput` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQEncoderOutput`, instead."
|
||||
deprecate("VQEncoderOutput", "0.31", deprecation_message)
|
||||
|
||||
|
||||
class VQModel(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VQ-VAE model for decoding latent representations.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
||||
Tuple of downsample block types.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
||||
Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
||||
Tuple of block output channels.
|
||||
layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
|
||||
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
|
||||
scaling_factor (`float`, *optional*, defaults to `0.18215`):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
||||
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
||||
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
norm_type (`str`, *optional*, defaults to `"group"`):
|
||||
Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 3,
|
||||
sample_size: int = 32,
|
||||
num_vq_embeddings: int = 256,
|
||||
norm_num_groups: int = 32,
|
||||
vq_embed_dim: Optional[int] = None,
|
||||
scaling_factor: float = 0.18215,
|
||||
norm_type: str = "group", # group, spatial
|
||||
mid_block_add_attention=True,
|
||||
lookup_from_codebook=False,
|
||||
force_upcast=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
out_channels=latent_channels,
|
||||
down_block_types=down_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
double_z=False,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
|
||||
|
||||
self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
|
||||
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
|
||||
self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = Decoder(
|
||||
in_channels=latent_channels,
|
||||
out_channels=out_channels,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_type=norm_type,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
|
||||
if not return_dict:
|
||||
return (h,)
|
||||
|
||||
return VQEncoderOutput(latents=h)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, commit_loss, _ = self.quantize(h)
|
||||
elif self.config.lookup_from_codebook:
|
||||
quant = self.quantize.get_codebook_entry(h, shape)
|
||||
commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
|
||||
else:
|
||||
quant = h
|
||||
commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
|
||||
quant2 = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
|
||||
|
||||
if not return_dict:
|
||||
return dec, commit_loss
|
||||
|
||||
return DecoderOutput(sample=dec, commit_loss=commit_loss)
|
||||
|
||||
def forward(
|
||||
self, sample: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
|
||||
r"""
|
||||
The [`VQModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vq_model.VQEncoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
|
||||
is returned.
|
||||
"""
|
||||
|
||||
h = self.encode(sample).latents
|
||||
dec = self.decode(h)
|
||||
|
||||
if not return_dict:
|
||||
return dec.sample, dec.commit_loss
|
||||
return dec
|
||||
class VQModel(VQModel):
|
||||
deprecation_message = "Importing `VQModel` from `diffusers.models.vq_model` is deprecated and this will be removed in a future version. Please use `from diffusers.models.autoencoders.vq_model import VQModel`, instead."
|
||||
deprecate("VQModel", "0.31", deprecation_message)
|
||||
|
||||
Reference in New Issue
Block a user