From 333a8da678882930a3c981df09128c7d8ffd0212 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Wed, 29 Jun 2022 13:52:04 +0200 Subject: [PATCH] add tests for AutoencoderKL --- src/diffusers/models/vae.py | 6 +-- tests/test_modeling_utils.py | 73 ++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index bfc4a96e00..6bd9a07099 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -626,11 +626,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin): dec = self.decoder(z) return dec - def forward(self, input, sample_posterior=True): - posterior = self.encode(input) + def forward(self, x, sample_posterior=False): + posterior = self.encode(x) if sample_posterior: z = posterior.sample() else: z = posterior.mode() dec = self.decode(z) - return dec, posterior + return dec diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index f38e629d9e..c567fdbbdb 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -46,6 +46,7 @@ from diffusers import ( UNetLDMModel, UNetModel, VQModel, + AutoencoderKL, ) from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline @@ -883,6 +884,78 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) +class AutoEncoderKLTests(ModelTesterMixin, unittest.TestCase): + model_class = AutoencoderKL + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"x": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 64, + "ch_mult": (1,), + "embed_dim": 4, + "in_channels": 3, + "num_res_blocks": 1, + "out_ch": 3, + "resolution": 32, + "z_channels": 4, + "attn_resolutions": [] + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + def test_from_pretrained_hub(self): + model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + with torch.no_grad(): + output = model(image, sample_posterior=True) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.0814, -0.0229, -0.1320, -0.4123, -0.0366, -0.3473, 0.0438, -0.1662, + 0.1750]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + class PipelineTesterMixin(unittest.TestCase): def test_from_pretrained_save_pretrained(self): # 1. Load models