mirror of https://github.com/huggingface/diffusers.git

add sketch of tests (need more changes)

Nathan Lambert
2022-11-29 17:05:51 -08:00
parent 01b0b868a4
commit bbd9043be4
2 changed files with 108 additions and 1 deletion


@@ -132,7 +132,7 @@ class ALDScheduler(SchedulerMixin, ConfigMixin):
         self.set_timesteps(num_inference_steps)
         self.sigmas = torch.tensor(
-            torch.exp(torch.linspace(torch.log(sigma_min), torch.log(sigma_max), num_inference_steps)),
+            np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)),
             dtype=torch.float32,
         )
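
Note on the one-line change above: `torch.log` raises a TypeError when given a plain Python float, which is what the `sigma_min`/`sigma_max` config values are, and wrapping an existing tensor in `torch.tensor(...)` triggers a copy-construct warning. Computing the log-spaced (geometric) sigma schedule in NumPy and converting once avoids both. A minimal sketch, using the config values from the tests below:

    import numpy as np
    import torch

    sigma_min, sigma_max, num_inference_steps = 0.01, 1.0, 10
    sigmas = torch.tensor(
        np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)),
        dtype=torch.float32,
    )
    # endpoints of the geometric schedule match the configured bounds
    assert abs(sigmas[0].item() - sigma_min) < 1e-6
    assert abs(sigmas[-1].item() - sigma_max) < 1e-6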


@@ -20,6 +20,7 @@ import numpy as np
 import torch
 from diffusers import (
+    ALDScheduler,
     DDIMScheduler,
     DDPMScheduler,
     IPNDMScheduler,
@@ -875,6 +876,112 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
+
+
+class ALDSchedulerTest(unittest.TestCase):
+    # TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration), similar to SDE VE
+    scheduler_classes = (ALDScheduler,)
+    forward_default_kwargs = ()
+
+    @property
+    def dummy_sample(self):
+        batch_size = 4
+        num_channels = 3
+        height = 8
+        width = 8
+
+        sample = torch.rand((batch_size, num_channels, height, width))
+
+        return sample
+
+    @property
+    def dummy_sample_deter(self):
+        batch_size = 4
+        num_channels = 3
+        height = 8
+        width = 8
+
+        num_elems = batch_size * num_channels * height * width
+        sample = torch.arange(num_elems)
+        sample = sample.reshape(num_channels, height, width, batch_size)
+        sample = sample / num_elems
+        sample = sample.permute(3, 0, 1, 2)
+
+        return sample
+
+    def dummy_model(self):
+        def model(sample, t, *args):
+            return sample * t / (t + 1)
+
+        return model
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 100,
+            "sigma_min": 0.01,
+            "sigma_max": 1.0,
+            "step_lr": 0.00002,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def check_over_configs(self, time_step=0, **config):
+        kwargs = dict(self.forward_default_kwargs)
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            scheduler_config = self.get_scheduler_config(**config)
+            scheduler = scheduler_class(**scheduler_config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+
+            output = scheduler.step(
+                residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
+            ).prev_sample
+            new_output = new_scheduler.step(
+                residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
+            ).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
+        kwargs = dict(self.forward_default_kwargs)
+        kwargs.update(forward_kwargs)
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_config(tmpdirname)
+                new_scheduler = scheduler_class.from_config(tmpdirname)
+
+            output = scheduler.step(
+                residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
+            ).prev_sample
+            new_output = new_scheduler.step(
+                residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
+            ).prev_sample
+
+            assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_timesteps(self):
+        for timesteps in [10, 100, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_sigmas(self):
+        for sigma_min, sigma_max in zip([0.0001, 0.001, 0.01], [1, 1, 1]):
+            self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)
+
+    def test_time_indices(self):
+        for t in [0.1, 0.5, 0.75]:
+            self.check_over_forward(time_step=t)
+
+
 class LMSDiscreteSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (LMSDiscreteScheduler,)
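
For orientation, a hypothetical end-to-end sketch of the loop these tests exercise. It assumes only what the tests above show: a `step(residual, timestep, sample, generator=...)` call returning an object with `.prev_sample`, and the config keys from `get_scheduler_config`. Per the commit message, ALDScheduler is itself still a sketch, so treat this as illustrative rather than the library's final API.

    import torch
    from diffusers import ALDScheduler

    scheduler = ALDScheduler(num_train_timesteps=100, sigma_min=0.01, sigma_max=1.0, step_lr=0.00002)

    def model(sample, t):
        # stand-in for a score network, mirroring dummy_model above
        return sample * t / (t + 1)

    sample = torch.rand((4, 3, 8, 8))  # same shape as dummy_sample
    for t in [0.75, 0.5, 0.1]:  # time indices as in test_time_indices
        residual = model(sample, t)
        sample = scheduler.step(residual, t, sample, generator=torch.manual_seed(0)).prev_sample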