mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add test for loading model from pipeline module
This commit is contained in:
@@ -19,9 +19,10 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
|
||||
from diffusers import DDIM, DDPM, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.pipeline_bddm import DiffWave
|
||||
from diffusers.testing_utils import floats_tensor, slow, torch_device
|
||||
|
||||
|
||||
@@ -212,3 +213,19 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
assert image.shape == (1, 3, 256, 256)
|
||||
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
def test_module_from_pipeline(self):
|
||||
model = DiffWave(num_res_layers=4)
|
||||
noise_scheduler = DDPMScheduler(timesteps=12)
|
||||
|
||||
bddm = BDDM(model, noise_scheduler)
|
||||
|
||||
# check if the library name for the diffwave moduel is set to pipeline module
|
||||
self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm")
|
||||
|
||||
# check if we can save and load the pipeline
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
bddm.save_pretrained(tmpdirname)
|
||||
_ = BDDM.from_pretrained(tmpdirname)
|
||||
# check if the same works using the DifusionPipeline class
|
||||
_ = DiffusionPipeline.from_pretrained(tmpdirname)
|
||||
Reference in New Issue
Block a user