diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 2e0be1d334..f38e629d9e 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -22,6 +22,7 @@ import numpy as np import torch from diffusers import ( + AutoencoderKL, BDDMPipeline, DDIMPipeline, DDIMScheduler, @@ -44,6 +45,7 @@ from diffusers import ( UNetGradTTSModel, UNetLDMModel, UNetModel, + VQModel, ) from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline @@ -805,6 +807,82 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) +class VQModelTests(ModelTesterMixin, unittest.TestCase): + model_class = VQModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 3 + sizes = (32, 32) + + image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + + return {"x": image} + + @property + def input_shape(self): + return (3, 32, 32) + + @property + def output_shape(self): + return (3, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "ch": 64, + "out_ch": 3, + "num_res_blocks": 1, + "attn_resolutions": [], + "in_channels": 3, + "resolution": 32, + "z_channels": 3, + "n_embed": 256, + "embed_dim": 3, + "sane_index_shape": False, + "ch_mult": (1,), + "dropout": 0.0, + "double_z": False, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_training(self): + pass + + def test_from_pretrained_hub(self): + model, loading_info = VQModel.from_pretrained("fusing/vqgan-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = VQModel.from_pretrained("fusing/vqgan-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + image = torch.randn(1, model.config.in_channels, model.config.resolution, model.config.resolution) + with torch.no_grad(): + output = model(image) + + output_slice = output[0, -1, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-1.1321, 0.1056, 0.3505, -0.6461, -0.2014, 0.0419, -0.5763, -0.8462, + -0.4218]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + class PipelineTesterMixin(unittest.TestCase): def test_from_pretrained_save_pretrained(self): # 1. Load models