1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

hotfix for pdnm test (#220)

This commit is contained in:
Nathan Lambert
2022-08-21 22:23:59 -07:00
committed by GitHub
parent 6a03060c45
commit 3f1861ee46

View File

@@ -426,16 +426,18 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residuals
# copy over dummy past residuals (must be after setting timesteps)
scheduler.ets = dummy_past_residuals[:]
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_timesteps(num_inference_steps)
# copy over dummy past residual (must be after setting timesteps)
new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
@@ -461,12 +463,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(tensor_format="np", **scheduler_config)
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
# copy over dummy past residuals
scheduler_pt.ets = dummy_past_residuals_pt[:]
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -474,6 +472,10 @@ class PNDMSchedulerTest(SchedulerCommonTest):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
# copy over dummy past residuals (must be done after set_timesteps)
scheduler.ets = dummy_past_residuals[:]
scheduler_pt.ets = dummy_past_residuals_pt[:]
output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
@@ -494,15 +496,16 @@ class PNDMSchedulerTest(SchedulerCommonTest):
sample = self.dummy_sample
residual = 0.1 * sample
# copy over dummy past residuals
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler.ets = dummy_past_residuals[:]
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
# copy over dummy past residuals (must be done after set_timesteps)
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler.ets = dummy_past_residuals[:]
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]