1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add test for loading model from pipeline module

This commit is contained in:
patil-suraj
2022-06-14 12:50:40 +02:00
parent d81b56ba5c
commit 147d8e0702

View File

@@ -19,9 +19,10 @@ import unittest
import torch
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
from diffusers import DDIM, DDPM, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_bddm import DiffWave
from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -212,3 +213,19 @@ class PipelineTesterMixin(unittest.TestCase):
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12)
bddm = BDDM(model, noise_scheduler)
# check if the library name for the diffwave moduel is set to pipeline module
self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm")
# check if we can save and load the pipeline
with tempfile.TemporaryDirectory() as tmpdirname:
bddm.save_pretrained(tmpdirname)
_ = BDDM.from_pretrained(tmpdirname)
# check if the same works using the DifusionPipeline class
_ = DiffusionPipeline.from_pretrained(tmpdirname)