Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
PNDM API Updates, Tests Cleaning (#103)
* organize PNDM tests, begin API change
* clean timestep API PNDM
* update pipeline PNDM
* fix typo
* API clean round 2
* small nit
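At a glance, the change replaces the scheduler's get_prk_time_steps/get_time_steps lookups and the prk/plms mode switching with a single set_timesteps call plus precomputed prk_timesteps and plms_timesteps lists, and the step functions now receive the loop counter rather than a looked-up timestep. The sketch below is not part of the diff; it is a minimal, self-contained illustration of the updated usage pattern, modeled on the pipeline and test changes further down. The stub dummy_model stands in for the UNet, and PNDMScheduler is assumed to be importable from the top-level diffusers package at this point in the repository's history.

# Minimal usage sketch of the reworked PNDM timestep API (illustrative only).
import numpy as np
from diffusers import PNDMScheduler

def dummy_model(sample, t):
    # Stand-in for unet(sample, t)["sample"]: just returns a scaled residual.
    return 0.1 * sample

scheduler = PNDMScheduler(tensor_format="np")
num_inference_steps = 10
scheduler.set_timesteps(num_inference_steps)  # replaces get_prk_time_steps/get_time_steps

sample = np.zeros((1, 3, 8, 8))

# Runge-Kutta (PRK) warm-up steps over the precomputed prk_timesteps.
for i, t in enumerate(scheduler.prk_timesteps):
    residual = dummy_model(sample, t)
    sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]

# Linear multistep (PLMS) steps over the remaining plms_timesteps.
for i, t in enumerate(scheduler.plms_timesteps):
    residual = dummy_model(sample, t)
    sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]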
@@ -43,19 +43,16 @@ class PNDMPipeline(DiffusionPipeline):
         )
         image = image.to(torch_device)
 
-        prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
-        for t in tqdm(range(len(prk_time_steps))):
-            t_orig = prk_time_steps[t]
-            model_output = self.unet(image, t_orig)["sample"]
+        self.scheduler.set_timesteps(num_inference_steps)
+        for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
+            model_output = self.unet(image, t)["sample"]
 
-            image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"]
+            image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]
 
-        timesteps = self.scheduler.get_time_steps(num_inference_steps)
-        for t in tqdm(range(len(timesteps))):
-            t_orig = timesteps[t]
-            model_output = self.unet(image, t_orig)["sample"]
+        for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
+            model_output = self.unet(image, t)["sample"]
 
-            image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]
+            image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
@@ -15,6 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
+import pdb
 from typing import Union
 
 import numpy as np
@@ -71,8 +72,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         self.one = np.array(1.0)
 
-        self.set_format(tensor_format=tensor_format)
-
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -82,49 +81,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.cur_model_output = 0
         self.cur_sample = None
         self.ets = []
-        self.prk_time_steps = {}
-        self.time_steps = {}
-        self.set_prk_mode()
 
-    def get_prk_time_steps(self, num_inference_steps):
-        if num_inference_steps in self.prk_time_steps:
-            return self.prk_time_steps[num_inference_steps]
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.prk_timesteps = None
+        self.plms_timesteps = None
 
-        inference_step_times = list(
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def set_timesteps(self, num_inference_steps):
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = list(
             range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
         )
 
-        prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
+        prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile(
             np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
         )
-        self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
+        self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
+        self.plms_timesteps = list(reversed(self.timesteps[:-3]))
 
-        return self.prk_time_steps[num_inference_steps]
-
-    def get_time_steps(self, num_inference_steps):
-        if num_inference_steps in self.time_steps:
-            return self.time_steps[num_inference_steps]
-
-        inference_step_times = list(
-            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
-        )
-        self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
-
-        return self.time_steps[num_inference_steps]
-
-    def set_prk_mode(self):
-        self.mode = "prk"
-
-    def set_plms_mode(self):
-        self.mode = "plms"
-
-    def step(self, *args, **kwargs):
-        if self.mode == "prk":
-            return self.step_prk(*args, **kwargs)
-        if self.mode == "plms":
-            return self.step_plms(*args, **kwargs)
-
-        raise ValueError(f"mode {self.mode} does not exist.")
+        self.set_format(tensor_format=self.tensor_format)
 
     def step_prk(
         self,
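As a concrete illustration of what the new set_timesteps computes, the following standalone sketch (not part of the diff) mirrors the formulas above with assumed values: num_train_timesteps=1000, num_inference_steps=10, and pndm_order=4 (PNDM's order is assumed to be 4, matching the Runge-Kutta method the comments refer to).

import numpy as np

# Standalone sketch mirroring set_timesteps above; the concrete values are illustrative assumptions.
num_train_timesteps = 1000
num_inference_steps = 10
pndm_order = 4

timesteps = list(range(0, num_train_timesteps, num_train_timesteps // num_inference_steps))
# timesteps == [0, 100, 200, ..., 900]

prk_time_steps = np.array(timesteps[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, num_train_timesteps // num_inference_steps // 2]), pndm_order
)
prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
# prk_timesteps == [900, 850, 850, 800, 800, 750, 750, 700, 700, 650, 650, 600]

plms_timesteps = list(reversed(timesteps[:-3]))
# plms_timesteps == [600, 500, 400, 300, 200, 100, 0]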
@@ -138,7 +117,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         solution to the differential equation.
         """
         t = timestep
-        prk_time_steps = self.get_prk_time_steps(num_inference_steps)
+        prk_time_steps = self.prk_timesteps
 
         t_orig = prk_time_steps[t // 4 * 4]
         t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
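To see how step_prk maps the counter-style timestep argument onto actual diffusion timesteps, here is a small sketch (not part of the diff, reusing the illustrative prk_timesteps from the previous example): t // 4 * 4 pins all four Runge-Kutta sub-evaluations to the same base timestep, while t + 1 selects the half-step neighbour.

# Sketch of the index -> timestep mapping used in step_prk (illustrative only).
prk_timesteps = [900, 850, 850, 800, 800, 750, 750, 700, 700, 650, 650, 600]

for t in range(len(prk_timesteps)):
    t_orig = prk_timesteps[t // 4 * 4]
    t_orig_prev = prk_timesteps[min(t + 1, len(prk_timesteps) - 1)]
    print(t, t_orig, t_orig_prev)
# Each group of four counter values (0-3, 4-7, 8-11) shares the same t_orig
# (900, 800, 700), matching the four evaluations of one Runge-Kutta step.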
@@ -180,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                 "for more information."
             )
 
-        timesteps = self.get_time_steps(num_inference_steps)
+        timesteps = self.plms_timesteps
 
         t_orig = timesteps[t]
         t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
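The guard that raises just above this hunk is exercised by test_inference_plms_no_past_residuals further down: step_plms raises a ValueError when it is called without any accumulated ets history from prior PRK steps. A minimal sketch of that behaviour follows (not part of the diff; it assumes PNDMScheduler is importable from the top-level package and simply mirrors the test's call).

import numpy as np
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(tensor_format="np")
dummy = np.zeros((1, 3, 8, 8))

try:
    # No step_prk calls yet, so scheduler.ets is empty and PLMS has no history.
    scheduler.step_plms(dummy, 1, dummy, 50)["prev_sample"]
except ValueError as err:
    print("step_plms without PRK warm-up raised:", err)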
@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase):
         num_inference_steps = kwargs.pop("num_inference_steps", None)
 
         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             sample = self.dummy_sample
             residual = 0.1 * sample
 
@@ -102,7 +101,6 @@
             sample = self.dummy_sample
             residual = 0.1 * sample
 
-            scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
 
@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         config.update(**kwargs)
         return config
 
-    def check_over_configs_pmls(self, time_step=0, **config):
+    def check_over_configs(self, time_step=0, **config):
         kwargs = dict(self.forward_default_kwargs)
         sample = self.dummy_sample
         residual = 0.1 * sample
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
 
         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config(**config)
             scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(kwargs["num_inference_steps"])
             # copy over dummy past residuals
             scheduler.ets = dummy_past_residuals[:]
-            scheduler.set_plms_mode()
 
             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
+                new_scheduler.set_timesteps(kwargs["num_inference_steps"])
                 # copy over dummy past residuals
                 new_scheduler.ets = dummy_past_residuals[:]
-                new_scheduler.set_plms_mode()
 
-            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
-            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+            output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
 
             assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
+            output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_pretrained_save_pretrained(self):
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
         kwargs = dict(self.forward_default_kwargs)
         kwargs.update(forward_kwargs)
         sample = self.dummy_sample
@@ -409,74 +414,127 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
 
         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(kwargs["num_inference_steps"])
+
             # copy over dummy past residuals
             scheduler.ets = dummy_past_residuals[:]
-            scheduler.set_plms_mode()
 
             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
                 # copy over dummy past residuals
                 new_scheduler.ets = dummy_past_residuals[:]
-                new_scheduler.set_plms_mode()
+                new_scheduler.set_timesteps(kwargs["num_inference_steps"])
 
-            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
-            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+            output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
 
             assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
+            output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_pytorch_equal_numpy(self):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+
+            sample_pt = torch.tensor(sample)
+            residual_pt = 0.1 * sample_pt
+            dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            # copy over dummy past residuals
+            scheduler.ets = dummy_past_residuals[:]
+
+            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
+            # copy over dummy past residuals
+            scheduler_pt.ets = dummy_past_residuals_pt[:]
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+                scheduler_pt.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+            output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+            # copy over dummy past residuals
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+            scheduler.ets = dummy_past_residuals[:]
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+            output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
     def test_timesteps(self):
         for timesteps in [100, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
 
-    def test_timesteps_pmls(self):
-        for timesteps in [100, 1000]:
-            self.check_over_configs_pmls(num_train_timesteps=timesteps)
-
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
             self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
 
-    def test_betas_pmls(self):
-        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
-            self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
-
     def test_schedules(self):
         for schedule in ["linear", "squaredcos_cap_v2"]:
             self.check_over_configs(beta_schedule=schedule)
 
-    def test_schedules_pmls(self):
-        for schedule in ["linear", "squaredcos_cap_v2"]:
-            self.check_over_configs(beta_schedule=schedule)
-
     def test_time_indices(self):
         for t in [1, 5, 10]:
             self.check_over_forward(time_step=t)
 
-    def test_time_indices_pmls(self):
-        for t in [1, 5, 10]:
-            self.check_over_forward_pmls(time_step=t)
-
     def test_inference_steps(self):
         for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
             self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
 
-    def test_inference_steps_pmls(self):
-        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
-            self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
-
-    def test_inference_pmls_no_past_residuals(self):
+    def test_inference_plms_no_past_residuals(self):
         with self.assertRaises(ValueError):
             scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
 
-            scheduler.set_plms_mode()
-
-            scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
+            scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
 
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
@@ -486,20 +544,15 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         num_inference_steps = 10
         model = self.dummy_model()
         sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
 
-        prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
-        for t in range(len(prk_time_steps)):
-            t_orig = prk_time_steps[t]
-            residual = model(sample, t_orig)
+        for i, t in enumerate(scheduler.prk_timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
 
-            sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
-
-        timesteps = scheduler.get_time_steps(num_inference_steps)
-        for t in range(len(timesteps)):
-            t_orig = timesteps[t]
-            residual = model(sample, t_orig)
-
-            sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
+        for i, t in enumerate(scheduler.plms_timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
 
         result_sum = np.sum(np.abs(sample))
         result_mean = np.mean(np.abs(sample))
@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         kwargs = dict(self.forward_default_kwargs)
 
         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             sample = self.dummy_sample
             residual = 0.1 * sample
 
@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
             sample = self.dummy_sample
             residual = 0.1 * sample
 
-            scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
 