mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add sketch of tests (need more changes)
This commit is contained in:
@@ -132,7 +132,7 @@ class ALDScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
self.sigmas = torch.tensor(
|
||||
torch.exp(torch.linspace(torch.log(sigma_min), torch.log(sigma_max), num_inference_steps)),
|
||||
np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
ALDScheduler,
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
IPNDMScheduler,
|
||||
@@ -875,6 +876,112 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
class ALDSchedulerTest(unittest.TestCase):
|
||||
# TODO adapt with class SchedulerCommonTest (scheduler needs Numpy Integration), similar to SDE VE
|
||||
scheduler_classes = (ALDScheduler,)
|
||||
forward_default_kwargs = ()
|
||||
|
||||
@property
|
||||
def dummy_sample(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
height = 8
|
||||
width = 8
|
||||
|
||||
sample = torch.rand((batch_size, num_channels, height, width))
|
||||
|
||||
return sample
|
||||
|
||||
@property
|
||||
def dummy_sample_deter(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
height = 8
|
||||
width = 8
|
||||
|
||||
num_elems = batch_size * num_channels * height * width
|
||||
sample = torch.arange(num_elems)
|
||||
sample = sample.reshape(num_channels, height, width, batch_size)
|
||||
sample = sample / num_elems
|
||||
sample = sample.permute(3, 0, 1, 2)
|
||||
|
||||
return sample
|
||||
|
||||
def dummy_model(self):
|
||||
def model(sample, t, *args):
|
||||
return sample * t / (t + 1)
|
||||
|
||||
return model
|
||||
|
||||
def get_scheduler_config(self, **kwargs):
|
||||
config = {
|
||||
"num_train_timesteps": 100,
|
||||
"sigma_min": 0.01,
|
||||
"sigma_max": 1.0,
|
||||
"step_lr": 0.00002,
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
return config
|
||||
|
||||
def check_over_configs(self, time_step=0, **config):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
output = scheduler.step(
|
||||
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
|
||||
).prev_sample
|
||||
new_output = new_scheduler.step(
|
||||
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
|
||||
).prev_sample
|
||||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
kwargs.update(forward_kwargs)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
output = scheduler.step(
|
||||
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
|
||||
).prev_sample
|
||||
new_output = new_scheduler.step(
|
||||
residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
|
||||
).prev_sample
|
||||
|
||||
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [10, 100, 1000]:
|
||||
self.check_over_configs(num_train_timesteps=timesteps)
|
||||
|
||||
def test_sigmas(self):
|
||||
for sigma_min, sigma_max in zip([0.0001, 0.001, 0.01], [1, 1, 1]):
|
||||
self.check_over_configs(sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [0.1, 0.5, 0.75]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
||||
class LMSDiscreteSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (LMSDiscreteScheduler,)
|
||||
|
||||
Reference in New Issue
Block a user