diff --git a/src/diffusers/schedulers/scheduling_ald.py b/src/diffusers/schedulers/scheduling_ald.py index bb3633e3ee..82d3b8230e 100644 --- a/src/diffusers/schedulers/scheduling_ald.py +++ b/src/diffusers/schedulers/scheduling_ald.py @@ -132,7 +132,7 @@ class ALDScheduler(SchedulerMixin, ConfigMixin): self.set_timesteps(num_inference_steps) self.sigmas = torch.tensor( - torch.exp(torch.linspace(torch.log(sigma_min), torch.log(sigma_max), num_inference_steps)), + np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)), dtype=torch.float32, ) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 194fb66f66..0e1ddd9e59 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -20,6 +20,7 @@ import numpy as np import torch from diffusers import ( + ALDScheduler, DDIMScheduler, DDPMScheduler, IPNDMScheduler, @@ -875,6 +876,112 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) +class ALDSchedulerTest(unittest.TestCase): + # TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration), similar to SDE VE + scheduler_classes = (ALDScheduler,) + forward_default_kwargs = () + + @property + def dummy_sample(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + sample = torch.rand((batch_size, num_channels, height, width)) + + return sample + + @property + def dummy_sample_deter(self): + batch_size = 4 + num_channels = 3 + height = 8 + width = 8 + + num_elems = batch_size * num_channels * height * width + sample = torch.arange(num_elems) + sample = sample.reshape(num_channels, height, width, batch_size) + sample = sample / num_elems + sample = sample.permute(3, 0, 1, 2) + + return sample + + def dummy_model(self): + def model(sample, t, *args): + return sample * t / (t + 1) + + return model + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 100, + "sigma_min": 0.01, + "sigma_max": 1.0, + "step_lr": 0.00002, + } + + config.update(**kwargs) + return config + + def check_over_configs(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + + for scheduler_class in self.scheduler_classes: + sample = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_config(tmpdirname) + + output = scheduler.step( + residual, time_step, sample, generator=torch.manual_seed(0), **kwargs + ).prev_sample + new_output = new_scheduler.step( + residual, time_step, sample, generator=torch.manual_seed(0), **kwargs + ).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) + + for scheduler_class in self.scheduler_classes: + sample = self.dummy_sample + residual = 0.1 * sample + + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_config(tmpdirname) + + output = scheduler.step( + residual, time_step, sample, generator=torch.manual_seed(0), **kwargs + ).prev_sample + new_output = new_scheduler.step( + residual, time_step, sample, generator=torch.manual_seed(0), **kwargs + ).prev_sample + + assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_timesteps(self): + for timesteps in [10, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_sigmas(self): + for sigma_min, sigma_max in zip([0.0001, 0.001, 0.01], [1, 1, 1]): + self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max) + + def test_time_indices(self): + for t in [0.1, 0.5, 0.75]: + self.check_over_forward(time_step=t) class LMSDiscreteSchedulerTest(SchedulerCommonTest): scheduler_classes = (LMSDiscreteScheduler,)