diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index eb9debd18b..92612f9831 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -42,6 +42,18 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): class ModelTesterMixin(unittest.TestCase): + + @property + def dummy_input(self): + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes) + time_step = torch.tensor([10]) + + return (noise, time_step) + def test_from_pretrained_save_pretrained(self): config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2) model = UNetModel(config) @@ -50,13 +62,16 @@ class ModelTesterMixin(unittest.TestCase): model.save_pretrained(tmpdirname) new_model = UNetModel.from_pretrained(tmpdirname) - batch_size = 1 - num_channels = 3 - sizes = (32, 32) - noise = floats_tensor((batch_size, num_channels) + sizes) - time_step = torch.tensor([10]) + dummy_input = self.dummy_input - image = model(noise, time_step) - new_image = new_model(noise, time_step) + image = model(*dummy_input) + new_image = new_model(*dummy_input) assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass" + + def test_from_pretrained_hub(self): + model = UNetModel.from_pretrained("fusing/ddpm_dummy") + + image = model(*self.dummy_input) + + assert image is not None, "Make sure output is not None"