From 3f1861ee46f83c81efe3a5458fe6fef908941a78 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Sun, 21 Aug 2022 22:23:59 -0700 Subject: [PATCH] hotfix for pdnm test (#220) --- tests/test_scheduler.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index f125bb0dfb..b9e9c15bcd 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -426,16 +426,18 @@ class PNDMSchedulerTest(SchedulerCommonTest): scheduler = scheduler_class(**scheduler_config) scheduler.set_timesteps(num_inference_steps) - # copy over dummy past residuals + # copy over dummy past residuals (must be after setting timesteps) scheduler.ets = dummy_past_residuals[:] with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) # copy over dummy past residuals - new_scheduler.ets = dummy_past_residuals[:] new_scheduler.set_timesteps(num_inference_steps) + # copy over dummy past residual (must be after setting timesteps) + new_scheduler.ets = dummy_past_residuals[:] + output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] @@ -461,12 +463,8 @@ class PNDMSchedulerTest(SchedulerCommonTest): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(tensor_format="np", **scheduler_config) - # copy over dummy past residuals - scheduler.ets = dummy_past_residuals[:] scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config) - # copy over dummy past residuals - scheduler_pt.ets = dummy_past_residuals_pt[:] if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) @@ -474,6 +472,10 @@ class PNDMSchedulerTest(SchedulerCommonTest): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps + # copy over dummy past residuals (must be done after set_timesteps) + scheduler.ets = dummy_past_residuals[:] + scheduler_pt.ets = dummy_past_residuals_pt[:] + output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"] output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"] assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" @@ -494,15 +496,16 @@ class PNDMSchedulerTest(SchedulerCommonTest): sample = self.dummy_sample residual = 0.1 * sample - # copy over dummy past residuals - dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] - scheduler.ets = dummy_past_residuals[:] if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): scheduler.set_timesteps(num_inference_steps) elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps + # copy over dummy past residuals (must be done after set_timesteps) + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + scheduler.ets = dummy_past_residuals[:] + output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"] output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]