diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py
index d98699dd25..33ec1a3e98 100644
--- a/src/diffusers/pipelines/pndm/pipeline_pndm.py
+++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -43,19 +43,16 @@ class PNDMPipeline(DiffusionPipeline):
         )
         image = image.to(torch_device)

-        prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
-        for t in tqdm(range(len(prk_time_steps))):
-            t_orig = prk_time_steps[t]
-            model_output = self.unet(image, t_orig)["sample"]
+        self.scheduler.set_timesteps(num_inference_steps)
+        for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
+            model_output = self.unet(image, t)["sample"]

-            image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"]
+            image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]

-        timesteps = self.scheduler.get_time_steps(num_inference_steps)
-        for t in tqdm(range(len(timesteps))):
-            t_orig = timesteps[t]
-            model_output = self.unet(image, t_orig)["sample"]
+        for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
+            model_output = self.unet(image, t)["sample"]

-            image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]
+            image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
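For reference, a minimal sketch of the call pattern the refactored pipeline now relies on, driving the scheduler directly instead of through the removed get_prk_time_steps / get_time_steps helpers. The dummy unet, sample shape, and step count below are illustrative placeholders, not part of the patch, and assume the PNDMScheduler defaults from this revision:

import numpy as np
from diffusers import PNDMScheduler

def unet(sample, t):
    # Illustrative stand-in for a real UNet: returns a dict like the pipeline expects.
    return {"sample": 0.1 * sample}

num_inference_steps = 10
sample = np.ones((1, 3, 8, 8))
scheduler = PNDMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps)

# Runge-Kutta warmup: the loop counter i is the step index; step_prk looks the
# actual training timestep up in scheduler.prk_timesteps internally.
for i, t in enumerate(scheduler.prk_timesteps):
    model_output = unet(sample, t)["sample"]
    sample = scheduler.step_prk(model_output, i, sample, num_inference_steps)["prev_sample"]

# Linear multistep phase over the remaining timesteps, same indexing convention.
for i, t in enumerate(scheduler.plms_timesteps):
    model_output = unet(sample, t)["sample"]
    sample = scheduler.step_plms(model_output, i, sample, num_inference_steps)["prev_sample"]

The pipeline code above does exactly this, with the real UNet and a torch tensor in place of the numpy stand-ins.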
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index 216c4a715f..2c157e05d3 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -71,8 +72,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

         self.one = np.array(1.0)

-        self.set_format(tensor_format=tensor_format)
-
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -82,49 +81,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.cur_model_output = 0
         self.cur_sample = None
         self.ets = []
-        self.prk_time_steps = {}
-        self.time_steps = {}
-        self.set_prk_mode()

-    def get_prk_time_steps(self, num_inference_steps):
-        if num_inference_steps in self.prk_time_steps:
-            return self.prk_time_steps[num_inference_steps]
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.prk_timesteps = None
+        self.plms_timesteps = None

-        inference_step_times = list(
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def set_timesteps(self, num_inference_steps):
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = list(
             range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
         )

-        prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
+        prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile(
             np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
         )
-        self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
+        self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
+        self.plms_timesteps = list(reversed(self.timesteps[:-3]))

-        return self.prk_time_steps[num_inference_steps]
-
-    def get_time_steps(self, num_inference_steps):
-        if num_inference_steps in self.time_steps:
-            return self.time_steps[num_inference_steps]
-
-        inference_step_times = list(
-            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
-        )
-        self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
-
-        return self.time_steps[num_inference_steps]
-
-    def set_prk_mode(self):
-        self.mode = "prk"
-
-    def set_plms_mode(self):
-        self.mode = "plms"
-
-    def step(self, *args, **kwargs):
-        if self.mode == "prk":
-            return self.step_prk(*args, **kwargs)
-        if self.mode == "plms":
-            return self.step_plms(*args, **kwargs)
-
-        raise ValueError(f"mode {self.mode} does not exist.")
+        self.set_format(tensor_format=self.tensor_format)

     def step_prk(
         self,
@@ -138,7 +117,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         solution to the differential equation.
         """
         t = timestep
-        prk_time_steps = self.get_prk_time_steps(num_inference_steps)
+        prk_time_steps = self.prk_timesteps

         t_orig = prk_time_steps[t // 4 * 4]
         t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
@@ -180,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                 "for more information."
             )

-        timesteps = self.get_time_steps(num_inference_steps)
+        timesteps = self.plms_timesteps

         t_orig = timesteps[t]
         t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
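To make the new set_timesteps precomputation concrete, the same arithmetic can be reproduced standalone. The values below (1000 training steps, 10 inference steps, a pndm_order of 4 as noted in the F-PNDM comment above) are illustrative, not taken from a specific config:

import numpy as np

num_train_timesteps = 1000   # assumed config value
num_inference_steps = 10     # assumed inference setting
pndm_order = 4               # F-PNDM order used by the scheduler

step = num_train_timesteps // num_inference_steps
base = list(range(0, num_train_timesteps, step))        # [0, 100, ..., 900]

prk = np.array(base[-pndm_order:]).repeat(2) + np.tile(np.array([0, step // 2]), pndm_order)
prk_timesteps = list(reversed(prk[:-1].repeat(2)[1:-1]))
# -> 900, 850, 850, 800, 800, 750, 750, 700, 700, 650, 650, 600 (12 entries)
plms_timesteps = list(reversed(base[:-3]))
# -> 600, 500, 400, 300, 200, 100, 0 (7 entries)

Each group of four PRK entries (for example 900, 850, 850, 800) covers one Runge-Kutta warmup step with two midpoint evaluations, which is why the pipeline iterates the PRK list once per model evaluation before switching to the shorter PLMS list.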
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index a409426a64..3059da1661 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase):
         num_inference_steps = kwargs.pop("num_inference_steps", None)

         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             sample = self.dummy_sample
             residual = 0.1 * sample

@@ -102,7 +101,6 @@ class SchedulerCommonTest(unittest.TestCase):
             sample = self.dummy_sample
             residual = 0.1 * sample

-            scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)

@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         config.update(**kwargs)
         return config

-    def check_over_configs_pmls(self, time_step=0, **config):
+    def check_over_configs(self, time_step=0, **config):
         kwargs = dict(self.forward_default_kwargs)
         sample = self.dummy_sample
         residual = 0.1 * sample
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config(**config)
             scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(kwargs["num_inference_steps"])
             # copy over dummy past residuals
             scheduler.ets = dummy_past_residuals[:]
-            scheduler.set_plms_mode()

             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
+                new_scheduler.set_timesteps(kwargs["num_inference_steps"])
                 # copy over dummy past residuals
                 new_scheduler.ets = dummy_past_residuals[:]
-                new_scheduler.set_plms_mode()

-            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
-            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+            output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]

             assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

-    def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
+            output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_pretrained_save_pretrained(self):
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
         kwargs = dict(self.forward_default_kwargs)
         kwargs.update(forward_kwargs)
         sample = self.dummy_sample
@@ -409,74 +414,127 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]

         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(kwargs["num_inference_steps"])
+
             # copy over dummy past residuals
             scheduler.ets = dummy_past_residuals[:]
-            scheduler.set_plms_mode()

             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
                 # copy over dummy past residuals
                 new_scheduler.ets = dummy_past_residuals[:]
-                new_scheduler.set_plms_mode()
+                new_scheduler.set_timesteps(kwargs["num_inference_steps"])

-            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
-            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+            output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]

             assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

+            output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_pytorch_equal_numpy(self):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+
+            sample_pt = torch.tensor(sample)
+            residual_pt = 0.1 * sample_pt
+            dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            # copy over dummy past residuals
+            scheduler.ets = dummy_past_residuals[:]
+
+            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
+            # copy over dummy past residuals
+            scheduler_pt.ets = dummy_past_residuals_pt[:]
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+                scheduler_pt.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+            output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+            # copy over dummy past residuals
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+            scheduler.ets = dummy_past_residuals[:]
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+            output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
     def test_timesteps(self):
         for timesteps in [100, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)

-    def test_timesteps_pmls(self):
-        for timesteps in [100, 1000]:
-            self.check_over_configs_pmls(num_train_timesteps=timesteps)
-
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
             self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

-    def test_betas_pmls(self):
-        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
-            self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
-
     def test_schedules(self):
         for schedule in ["linear", "squaredcos_cap_v2"]:
             self.check_over_configs(beta_schedule=schedule)

-    def test_schedules_pmls(self):
-        for schedule in ["linear", "squaredcos_cap_v2"]:
-            self.check_over_configs(beta_schedule=schedule)
-
     def test_time_indices(self):
         for t in [1, 5, 10]:
             self.check_over_forward(time_step=t)

-    def test_time_indices_pmls(self):
-        for t in [1, 5, 10]:
-            self.check_over_forward_pmls(time_step=t)
-
     def test_inference_steps(self):
         for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
             self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)

-    def test_inference_steps_pmls(self):
-        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
-            self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
-
-    def test_inference_pmls_no_past_residuals(self):
+    def test_inference_plms_no_past_residuals(self):
         with self.assertRaises(ValueError):
             scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
-            scheduler.set_plms_mode()
-
-            scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
+            scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]

     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
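The renamed test_inference_plms_no_past_residuals pins down the intended failure mode: step_plms refuses to run until the Runge-Kutta warmup has populated scheduler.ets. A small sketch of that contract, using illustrative values and the default scheduler config from this revision:

import numpy as np
from diffusers import PNDMScheduler

sample = np.ones((1, 3, 8, 8))   # stand-in for a latent batch
scheduler = PNDMScheduler()      # fresh scheduler, ets is still empty

try:
    # Mirrors the test: calling step_plms before any step_prk warmup is expected
    # to raise a ValueError rather than produce a sample.
    scheduler.step_plms(sample, 1, sample, 50)["prev_sample"]
except ValueError as err:
    print("step_plms without warmup raised:", err)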
@@ -486,20 +544,15 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         num_inference_steps = 10
         model = self.dummy_model()
         sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)

-        prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
-        for t in range(len(prk_time_steps)):
-            t_orig = prk_time_steps[t]
-            residual = model(sample, t_orig)
+        for i, t in enumerate(scheduler.prk_timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]

-            sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
-
-        timesteps = scheduler.get_time_steps(num_inference_steps)
-        for t in range(len(timesteps)):
-            t_orig = timesteps[t]
-            residual = model(sample, t_orig)
-
-            sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
+        for i, t in enumerate(scheduler.plms_timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]

         result_sum = np.sum(np.abs(sample))
         result_mean = np.mean(np.abs(sample))
@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         kwargs = dict(self.forward_default_kwargs)

         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             sample = self.dummy_sample
             residual = 0.1 * sample

@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
             sample = self.dummy_sample
             residual = 0.1 * sample

-            scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
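Finally, the intent behind the new test_pytorch_equal_numpy can be exercised by hand: the same config run once in the default numpy format and once with tensor_format="pt" should give matching step outputs. A minimal sketch under those assumptions (toy sample and residual values, default config, and the same tolerance the test uses):

import numpy as np
import torch
from diffusers import PNDMScheduler

num_inference_steps = 10
sample = 0.5 * np.ones((1, 3, 8, 8))
residual = 0.1 * sample
sample_pt = torch.tensor(sample)
residual_pt = 0.1 * sample_pt

scheduler = PNDMScheduler()                      # numpy tensors assumed as the default format
scheduler_pt = PNDMScheduler(tensor_format="pt")

# Seed both schedulers with the same dummy history so step_plms has residuals to use.
scheduler.ets = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler_pt.ets = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]

scheduler.set_timesteps(num_inference_steps)
scheduler_pt.set_timesteps(num_inference_steps)

out = scheduler.step_plms(residual, 1, sample, num_inference_steps)["prev_sample"]
out_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps)["prev_sample"]
assert np.sum(np.abs(out - out_pt.numpy())) < 1e-4   # matches the test's tolerance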