From 147d8e07029700e49a66991e6263fd2c39fd5fec Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Tue, 14 Jun 2022 12:50:40 +0200 Subject: [PATCH] add test for loading model from pipeline module --- tests/test_modeling_utils.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 6c119479fa..cacc356530 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -19,9 +19,10 @@ import unittest import torch -from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler +from diffusers import DDIM, DDPM, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.pipeline_bddm import DiffWave from diffusers.testing_utils import floats_tensor, slow, torch_device @@ -212,3 +213,19 @@ class PipelineTesterMixin(unittest.TestCase): assert image.shape == (1, 3, 256, 256) expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 + + def test_module_from_pipeline(self): + model = DiffWave(num_res_layers=4) + noise_scheduler = DDPMScheduler(timesteps=12) + + bddm = BDDM(model, noise_scheduler) + + # check if the library name for the diffwave moduel is set to pipeline module + self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm") + + # check if we can save and load the pipeline + with tempfile.TemporaryDirectory() as tmpdirname: + bddm.save_pretrained(tmpdirname) + _ = BDDM.from_pretrained(tmpdirname) + # check if the same works using the DifusionPipeline class + _ = DiffusionPipeline.from_pretrained(tmpdirname) \ No newline at end of file