From bf406ea8869b2cfeb826ef32e9c96367ea2185e9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 9 Nov 2023 13:10:24 +0100 Subject: [PATCH] Correct consist dec (#5722) * uP * Update src/diffusers/models/consistency_decoder_vae.py * uP * uP --- src/diffusers/models/autoencoder_asym_kl.py | 1 + src/diffusers/models/autoencoder_tiny.py | 6 +- .../models/consistency_decoder_vae.py | 71 ++++++++++++++++++- tests/models/test_models_vae.py | 55 ++++++-------- 4 files changed, 96 insertions(+), 37 deletions(-) diff --git a/src/diffusers/models/autoencoder_asym_kl.py b/src/diffusers/models/autoencoder_asym_kl.py index d809912091..9f0fa62d34 100644 --- a/src/diffusers/models/autoencoder_asym_kl.py +++ b/src/diffusers/models/autoencoder_asym_kl.py @@ -138,6 +138,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin): def decode( self, z: torch.FloatTensor, + generator: Optional[torch.Generator] = None, image: Optional[torch.FloatTensor] = None, mask: Optional[torch.FloatTensor] = None, return_dict: bool = True, diff --git a/src/diffusers/models/autoencoder_tiny.py b/src/diffusers/models/autoencoder_tiny.py index 407b1906bb..15bd53ff99 100644 --- a/src/diffusers/models/autoencoder_tiny.py +++ b/src/diffusers/models/autoencoder_tiny.py @@ -14,7 +14,7 @@ from dataclasses import dataclass -from typing import Tuple, Union +from typing import Optional, Tuple, Union import torch @@ -307,7 +307,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin): return AutoencoderTinyOutput(latents=output) @apply_forward_hook - def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: + def decode( + self, x: torch.FloatTensor, generator: Optional[torch.Generator] = None, return_dict: bool = True + ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]: if self.use_slicing and x.shape[0] > 1: output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)] output = torch.cat(output) diff --git a/src/diffusers/models/consistency_decoder_vae.py b/src/diffusers/models/consistency_decoder_vae.py index a2bcf1b8b7..9cb7c60bbb 100644 --- a/src/diffusers/models/consistency_decoder_vae.py +++ b/src/diffusers/models/consistency_decoder_vae.py @@ -68,11 +68,76 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin): """ @register_to_config - def __init__(self, encoder_args, decoder_args, scaling_factor, block_out_channels, latent_channels): + def __init__( + self, + scaling_factor=0.18215, + latent_channels=4, + encoder_act_fn="silu", + encoder_block_out_channels=(128, 256, 512, 512), + encoder_double_z=True, + encoder_down_block_types=( + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ), + encoder_in_channels=3, + encoder_layers_per_block=2, + encoder_norm_num_groups=32, + encoder_out_channels=4, + decoder_add_attention=False, + decoder_block_out_channels=(320, 640, 1024, 1024), + decoder_down_block_types=( + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ), + decoder_downsample_padding=1, + decoder_in_channels=7, + decoder_layers_per_block=3, + decoder_norm_eps=1e-05, + decoder_norm_num_groups=32, + decoder_num_train_timesteps=1024, + decoder_out_channels=6, + decoder_resnet_time_scale_shift="scale_shift", + decoder_time_embedding_type="learned", + decoder_up_block_types=( + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ), + ): super().__init__() - self.encoder = Encoder(**encoder_args) - self.decoder_unet = UNet2DModel(**decoder_args) + self.encoder = Encoder( + act_fn=encoder_act_fn, + block_out_channels=encoder_block_out_channels, + double_z=encoder_double_z, + down_block_types=encoder_down_block_types, + in_channels=encoder_in_channels, + layers_per_block=encoder_layers_per_block, + norm_num_groups=encoder_norm_num_groups, + out_channels=encoder_out_channels, + ) + + self.decoder_unet = UNet2DModel( + add_attention=decoder_add_attention, + block_out_channels=decoder_block_out_channels, + down_block_types=decoder_down_block_types, + downsample_padding=decoder_downsample_padding, + in_channels=decoder_in_channels, + layers_per_block=decoder_layers_per_block, + norm_eps=decoder_norm_eps, + norm_num_groups=decoder_norm_num_groups, + num_train_timesteps=decoder_num_train_timesteps, + out_channels=decoder_out_channels, + resnet_time_scale_shift=decoder_resnet_time_scale_shift, + time_embedding_type=decoder_time_embedding_type, + up_block_types=decoder_up_block_types, + ) self.decoder_scheduler = ConsistencyDecoderScheduler() + self.register_to_config(block_out_channels=encoder_block_out_channels) self.register_buffer( "means", torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None], diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 1f5b847dd1..3b698624ff 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -303,39 +303,30 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase): @property def init_dict(self): return { - "encoder_args": { - "block_out_channels": [32, 64], - "in_channels": 3, - "out_channels": 4, - "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], - }, - "decoder_args": { - "act_fn": "silu", - "add_attention": False, - "block_out_channels": [32, 64], - "down_block_types": [ - "ResnetDownsampleBlock2D", - "ResnetDownsampleBlock2D", - ], - "downsample_padding": 1, - "downsample_type": "conv", - "dropout": 0.0, - "in_channels": 7, - "layers_per_block": 1, - "norm_eps": 1e-05, - "norm_num_groups": 32, - "num_train_timesteps": 1024, - "out_channels": 6, - "resnet_time_scale_shift": "scale_shift", - "time_embedding_type": "learned", - "up_block_types": [ - "ResnetUpsampleBlock2D", - "ResnetUpsampleBlock2D", - ], - "upsample_type": "conv", - }, + "encoder_block_out_channels": [32, 64], + "encoder_in_channels": 3, + "encoder_out_channels": 4, + "encoder_down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "decoder_add_attention": False, + "decoder_block_out_channels": [32, 64], + "decoder_down_block_types": [ + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + ], + "decoder_downsample_padding": 1, + "decoder_in_channels": 7, + "decoder_layers_per_block": 1, + "decoder_norm_eps": 1e-05, + "decoder_norm_num_groups": 32, + "decoder_num_train_timesteps": 1024, + "decoder_out_channels": 6, + "decoder_resnet_time_scale_shift": "scale_shift", + "decoder_time_embedding_type": "learned", + "decoder_up_block_types": [ + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], "scaling_factor": 1, - "block_out_channels": [32, 64], "latent_channels": 4, }