mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add another test
This commit is contained in:
@@ -42,6 +42,18 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
|
||||
|
||||
class ModelTesterMixin(unittest.TestCase):
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes)
|
||||
time_step = torch.tensor([10])
|
||||
|
||||
return (noise, time_step)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
|
||||
model = UNetModel(config)
|
||||
@@ -50,13 +62,16 @@ class ModelTesterMixin(unittest.TestCase):
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = UNetModel.from_pretrained(tmpdirname)
|
||||
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes)
|
||||
time_step = torch.tensor([10])
|
||||
dummy_input = self.dummy_input
|
||||
|
||||
image = model(noise, time_step)
|
||||
new_image = new_model(noise, time_step)
|
||||
image = model(*dummy_input)
|
||||
new_image = new_model(*dummy_input)
|
||||
|
||||
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model = UNetModel.from_pretrained("fusing/ddpm_dummy")
|
||||
|
||||
image = model(*self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
Reference in New Issue
Block a user