1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

PNDM API Updates, Tests Cleaning (#103)

* organize PNDM tests, begin API change

* clean timestep API PNDM

* update pipeline PNDM

* fix typo

* API clean round 2

* small nit
This commit is contained in:
Nathan Lambert
2022-07-20 12:47:39 -07:00
committed by GitHub
parent 76f9b52289
commit 889aa6008c
3 changed files with 128 additions and 101 deletions

View File

@@ -43,19 +43,16 @@ class PNDMPipeline(DiffusionPipeline):
)
image = image.to(torch_device)
prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
for t in tqdm(range(len(prk_time_steps))):
t_orig = prk_time_steps[t]
model_output = self.unet(image, t_orig)["sample"]
self.scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
model_output = self.unet(image, t)["sample"]
image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"]
image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]
timesteps = self.scheduler.get_time_steps(num_inference_steps)
for t in tqdm(range(len(timesteps))):
t_orig = timesteps[t]
model_output = self.unet(image, t_orig)["sample"]
for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
model_output = self.unet(image, t)["sample"]
image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]
image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()

View File

@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import pdb
from typing import Union
import numpy as np
@@ -71,8 +72,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.one = np.array(1.0)
self.set_format(tensor_format=tensor_format)
# For now we only support F-PNDM, i.e. the Runge-Kutta method.
# For more information on the algorithm, please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formulas (9), (12), (13) and Algorithm 2.
@@ -82,49 +81,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.cur_model_output = 0
self.cur_sample = None
self.ets = []
self.prk_time_steps = {}
self.time_steps = {}
self.set_prk_mode()
def get_prk_time_steps(self, num_inference_steps):
if num_inference_steps in self.prk_time_steps:
return self.prk_time_steps[num_inference_steps]
# settable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.prk_timesteps = None
self.plms_timesteps = None
inference_step_times = list(
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps
self.timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self.timesteps[:-3]))
return self.prk_time_steps[num_inference_steps]
def get_time_steps(self, num_inference_steps):
if num_inference_steps in self.time_steps:
return self.time_steps[num_inference_steps]
inference_step_times = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
return self.time_steps[num_inference_steps]
def set_prk_mode(self):
self.mode = "prk"
def set_plms_mode(self):
self.mode = "plms"
def step(self, *args, **kwargs):
if self.mode == "prk":
return self.step_prk(*args, **kwargs)
if self.mode == "plms":
return self.step_plms(*args, **kwargs)
raise ValueError(f"mode {self.mode} does not exist.")
self.set_format(tensor_format=self.tensor_format)
def step_prk(
self,
@@ -138,7 +117,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
solution to the differential equation.
"""
t = timestep
prk_time_steps = self.get_prk_time_steps(num_inference_steps)
prk_time_steps = self.prk_timesteps
t_orig = prk_time_steps[t // 4 * 4]
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
@@ -180,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
)
timesteps = self.get_time_steps(num_inference_steps)
timesteps = self.plms_timesteps
t_orig = timesteps[t]
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]

View File

@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase):
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
sample = self.dummy_sample
residual = 0.1 * sample
@@ -102,7 +101,6 @@ class SchedulerCommonTest(unittest.TestCase):
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest):
config.update(**kwargs)
return config
def check_over_configs_pmls(self, time_step=0, **config):
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(kwargs["num_inference_steps"])
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
scheduler.set_plms_mode()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode()
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_from_pretrained_save_pretrained(self):
pass
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs)
sample = self.dummy_sample
@@ -409,74 +414,127 @@ class PNDMSchedulerTest(SchedulerCommonTest):
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(kwargs["num_inference_steps"])
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
scheduler.set_plms_mode()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode()
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_pytorch_equal_numpy(self):
# Compares the default-format scheduler with one built with tensor_format="pt":
# for each scheduler class, one step_prk call and one step_plms call must give
# "prev_sample" outputs whose summed absolute difference is below 1e-4.
# NOTE(review): indentation is not visible in this view; the statements after
# the `for` line presumably all belong to the loop body — confirm in the real file.
# Work on a copy so the pop below leaves self.forward_default_kwargs untouched.
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
# Default-format inputs, plus four offset residuals used to seed `ets`.
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
# The same inputs rebuilt as torch tensors for the "pt" scheduler.
sample_pt = torch.tensor(sample)
residual_pt = 0.1 * sample_pt
dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
# copy over dummy past residuals
scheduler_pt.ets = dummy_past_residuals_pt[:]
# Both schedulers are given the same inference-step count before stepping.
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
scheduler_pt.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
# NOTE(review): the calls below also pass num_inference_steps positionally, so
# taking this branch would presumably supply it twice — confirm whether this
# fallback is reachable for the schedulers under test.
kwargs["num_inference_steps"] = num_inference_steps
# PRK step: outputs of the two formats must agree within 1e-4.
output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
# PLMS step: same comparison, relying on the `ets` values seeded above.
output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_step_shape(self):
# Checks that step_prk and step_plms each return a "prev_sample" with the same
# shape as the input sample, and that the shape is the same for step arguments
# 0 and 1.
# NOTE(review): indentation is not visible in this view; the statements after
# the `for` line presumably all belong to the loop body — confirm in the real file.
# Work on a copy so the pop below leaves self.forward_default_kwargs untouched.
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
# copy over dummy past residuals
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler.ets = dummy_past_residuals[:]
# Configure the inference-step count on the scheduler when it exposes set_timesteps.
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
# NOTE(review): the calls below also pass num_inference_steps positionally, so
# taking this branch would presumably supply it twice — confirm whether this
# fallback is reachable for the schedulers under test.
kwargs["num_inference_steps"] = num_inference_steps
# PRK: output shape must match the sample and be identical for both step arguments.
output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
# PLMS: same shape checks, relying on the `ets` values seeded above.
output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_timesteps(self):
for timesteps in [100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_timesteps_pmls(self):
for timesteps in [100, 1000]:
self.check_over_configs_pmls(num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_betas_pmls(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_schedules_pmls(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_time_indices(self):
for t in [1, 5, 10]:
self.check_over_forward(time_step=t)
def test_time_indices_pmls(self):
for t in [1, 5, 10]:
self.check_over_forward_pmls(time_step=t)
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_inference_steps_pmls(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
def test_inference_pmls_no_past_residuals(self):
def test_inference_plms_no_past_residuals(self):
with self.assertRaises(ValueError):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_plms_mode()
scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
@@ -486,20 +544,15 @@ class PNDMSchedulerTest(SchedulerCommonTest):
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
for t in range(len(prk_time_steps)):
t_orig = prk_time_steps[t]
residual = model(sample, t_orig)
for i, t in enumerate(scheduler.prk_timesteps):
residual = model(sample, t)
sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
timesteps = scheduler.get_time_steps(num_inference_steps)
for t in range(len(timesteps)):
t_orig = timesteps[t]
residual = model(sample, t_orig)
sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t)
sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(sample))
@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
kwargs = dict(self.forward_default_kwargs)
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
sample = self.dummy_sample
residual = 0.1 * sample
@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)