mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
PNDM API Updates, Tests Cleaning (#103)
* organize PNDM tests, begin API change * clean timestep API PNDM * update pipeline PNDM * fix typo * API clean round 2 * small nit
This commit is contained in:
@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
@@ -102,7 +101,6 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
config.update(**kwargs)
|
||||
return config
|
||||
|
||||
def check_over_configs_pmls(self, time_step=0, **config):
|
||||
def check_over_configs(self, time_step=0, **config):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(kwargs["num_inference_steps"])
|
||||
# copy over dummy past residuals
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
scheduler.set_plms_mode()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
new_scheduler.set_plms_mode()
|
||||
|
||||
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
|
||||
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
pass
|
||||
|
||||
def check_over_forward(self, time_step=0, **forward_kwargs):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
kwargs.update(forward_kwargs)
|
||||
sample = self.dummy_sample
|
||||
@@ -409,74 +414,127 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
scheduler.set_timesteps(kwargs["num_inference_steps"])
|
||||
|
||||
# copy over dummy past residuals
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
scheduler.set_plms_mode()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
# copy over dummy past residuals
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
new_scheduler.set_plms_mode()
|
||||
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
|
||||
|
||||
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_pytorch_equal_numpy(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
|
||||
sample_pt = torch.tensor(sample)
|
||||
residual_pt = 0.1 * sample_pt
|
||||
dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
# copy over dummy past residuals
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
|
||||
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
|
||||
# copy over dummy past residuals
|
||||
scheduler_pt.ets = dummy_past_residuals_pt[:]
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler_pt.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
# copy over dummy past residuals
|
||||
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
|
||||
scheduler.ets = dummy_past_residuals[:]
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [100, 1000]:
|
||||
self.check_over_configs(num_train_timesteps=timesteps)
|
||||
|
||||
def test_timesteps_pmls(self):
|
||||
for timesteps in [100, 1000]:
|
||||
self.check_over_configs_pmls(num_train_timesteps=timesteps)
|
||||
|
||||
def test_betas(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
|
||||
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
|
||||
|
||||
def test_betas_pmls(self):
|
||||
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
|
||||
self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
|
||||
|
||||
def test_schedules(self):
|
||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||
self.check_over_configs(beta_schedule=schedule)
|
||||
|
||||
def test_schedules_pmls(self):
|
||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||
self.check_over_configs(beta_schedule=schedule)
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [1, 5, 10]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
||||
def test_time_indices_pmls(self):
|
||||
for t in [1, 5, 10]:
|
||||
self.check_over_forward_pmls(time_step=t)
|
||||
|
||||
def test_inference_steps(self):
|
||||
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
|
||||
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_inference_steps_pmls(self):
|
||||
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
|
||||
self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_inference_pmls_no_past_residuals(self):
|
||||
def test_inference_plms_no_past_residuals(self):
|
||||
with self.assertRaises(ValueError):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler.set_plms_mode()
|
||||
|
||||
scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
|
||||
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
@@ -486,20 +544,15 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
num_inference_steps = 10
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
|
||||
for t in range(len(prk_time_steps)):
|
||||
t_orig = prk_time_steps[t]
|
||||
residual = model(sample, t_orig)
|
||||
for i, t in enumerate(scheduler.prk_timesteps):
|
||||
residual = model(sample, t)
|
||||
sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
|
||||
|
||||
sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
|
||||
|
||||
timesteps = scheduler.get_time_steps(num_inference_steps)
|
||||
for t in range(len(timesteps)):
|
||||
t_orig = timesteps[t]
|
||||
residual = model(sample, t_orig)
|
||||
|
||||
sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
|
||||
for i, t in enumerate(scheduler.plms_timesteps):
|
||||
residual = model(sample, t)
|
||||
sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
|
||||
|
||||
result_sum = np.sum(np.abs(sample))
|
||||
result_mean = np.mean(np.abs(sample))
|
||||
@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user