From 394243ce98c79e273732f8e09b2e5c50ee9a9bd7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 21 Jul 2022 01:50:12 +0000 Subject: [PATCH] finish pndm sampler --- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 7 +- src/diffusers/pipelines/pndm/pipeline_pndm.py | 17 +++-- .../score_sde_ve/pipeline_score_sde_ve.py | 16 ++--- src/diffusers/schedulers/scheduling_ddim.py | 4 +- src/diffusers/schedulers/scheduling_ddpm.py | 1 - src/diffusers/schedulers/scheduling_pndm.py | 66 +++++++++++-------- tests/test_scheduler.py | 36 +++++----- 7 files changed, 77 insertions(+), 70 deletions(-) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index c947827f01..bab1c245f3 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -48,11 +48,8 @@ class DDPMPipeline(DiffusionPipeline): # 1. predict noise model_output model_output = self.unet(image, t)["sample"] - # 2. predict previous mean of image x_t-1 - pred_prev_image = self.scheduler.step(model_output, t, image)["prev_sample"] - - # 3. set current image to prev_image: x_t -> x_t-1 - image = pred_prev_image + # 2. 
compute previous image: x_t -> x_t-1 + image = self.scheduler.step(model_output, t, image)["prev_sample"] image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index 88e557f967..3c3c36f0dc 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -44,15 +44,20 @@ class PNDMPipeline(DiffusionPipeline): image = image.to(torch_device) self.scheduler.set_timesteps(num_inference_steps) - for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)): + for t in tqdm(self.scheduler.timesteps): model_output = self.unet(image, t)["sample"] - image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"] + image = self.scheduler.step(model_output, t, image)["prev_sample"] - for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)): - model_output = self.unet(image, t)["sample"] - - image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"] + # for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)): + # model_output = self.unet(image, t)["sample"] + # + # image = self.scheduler.step_prk(model_output, t, image, i=i)["prev_sample"] + # + # for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)): + # model_output = self.unet(image, t)["sample"] + # + # image = self.scheduler.step_plms(model_output, t, image, i=i)["prev_sample"] image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index ba8fbd762c..6344a578b9 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -28,21 +28,15 @@ class ScoreSdeVePipeline(DiffusionPipeline): for i, t in 
tqdm(enumerate(self.scheduler.timesteps)): sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device) + # correction step for _ in range(self.scheduler.correct_steps): - model_output = self.model(sample, sigma_t) - - if isinstance(model_output, dict): - model_output = model_output["sample"] - + model_output = self.model(sample, sigma_t)["sample"] sample = self.scheduler.step_correct(model_output, sample)["prev_sample"] - with torch.no_grad(): - model_output = model(sample, sigma_t) - - if isinstance(model_output, dict): - model_output = model_output["sample"] - + # prediction step + model_output = model(sample, sigma_t)["sample"] output = self.scheduler.step_pred(model_output, t, sample) + sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] sample = sample.clamp(0, 1) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 83c313acf3..ed76873f8a 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -106,8 +106,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray], - eta, - use_clipped_model_output=False, + eta: float = 0.0, + use_clipped_model_output: bool = False, generator=None, ): # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 0a0a29e17e..d8f75f4bdd 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -56,7 +56,6 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): beta_end=0.02, beta_schedule="linear", trained_betas=None, - timestep_values=None, variance_type="fixed_small", clip_sample=True, tensor_format="pt", diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py 
index df30a269f0..03c0f913f3 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -15,7 +15,6 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -import pdb from typing import Union import numpy as np @@ -79,78 +78,91 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): # running values self.cur_model_output = 0 + self.counter = 0 self.cur_sample = None self.ets = [] # setable values self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy() self.prk_timesteps = None self.plms_timesteps = None + self.timesteps = None self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) def set_timesteps(self, num_inference_steps): self.num_inference_steps = num_inference_steps - self.timesteps = list( + self._timesteps = list( range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps) ) - prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile( + prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile( np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order ) - self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1])) - self.plms_timesteps = list(reversed(self.timesteps[:-3])) + self.prk_timesteps = list(reversed(prk_timesteps[:-1].repeat(2)[1:-1])) + self.plms_timesteps = list(reversed(self._timesteps[:-3])) + self.timesteps = self.prk_timesteps + self.plms_timesteps + self.counter = 0 self.set_format(tensor_format=self.tensor_format) + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + ): + if self.counter < len(self.prk_timesteps): + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample) + 
else: + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample) + def step_prk( self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray], - num_inference_steps, ): """ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the solution to the differential equation. """ - t = timestep - prk_time_steps = self.prk_timesteps + diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2 + prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1]) + timestep = self.prk_timesteps[self.counter // 4 * 4] - t_orig = prk_time_steps[t // 4 * 4] - t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)] - - if t % 4 == 0: + if self.counter % 4 == 0: self.cur_model_output += 1 / 6 * model_output self.ets.append(model_output) self.cur_sample = sample - elif (t - 1) % 4 == 0: + elif (self.counter - 1) % 4 == 0: self.cur_model_output += 1 / 3 * model_output - elif (t - 2) % 4 == 0: + elif (self.counter - 2) % 4 == 0: self.cur_model_output += 1 / 3 * model_output - elif (t - 3) % 4 == 0: + elif (self.counter - 3) % 4 == 0: model_output = self.cur_model_output + 1 / 6 * model_output self.cur_model_output = 0 # cur_sample should not be `None` cur_sample = self.cur_sample if self.cur_sample is not None else sample - return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, model_output)} + prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + self.counter += 1 + + return {"prev_sample": prev_sample} def step_plms( self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray], - num_inference_steps, ): """ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple times to approximate the solution. 
""" - t = timestep if len(self.ets) < 3: raise ValueError( f"{self.__class__} can only be run AFTER scheduler has been run " @@ -159,17 +171,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): "for more information." ) - timesteps = self.plms_timesteps - - t_orig = timesteps[t] - t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)] + prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0) self.ets.append(model_output) model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) - return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, model_output)} + prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) + self.counter += 1 - def get_prev_sample(self, sample, t_orig, t_orig_prev, model_output): + return {"prev_sample": prev_sample} + + def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf # this function computes x_(t−δ) using the formula of (9) # Note that x_t needs to be added to both sides of the equation @@ -182,8 +194,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): # sample -> x_t # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) - alpha_prod_t = self.alphas_cumprod[t_orig + 1] - alpha_prod_t_prev = self.alphas_cumprod[t_orig_prev + 1] + alpha_prod_t = self.alphas_cumprod[timestep + 1] + alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1] beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 985ff3ae64..f125bb0dfb 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import pdb import tempfile import unittest @@ -383,6 +382,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): def check_over_configs(self, time_step=0, **config): kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) sample = self.dummy_sample residual = 0.1 * sample dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] @@ -390,14 +390,14 @@ class PNDMSchedulerTest(SchedulerCommonTest): for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config(**config) scheduler = scheduler_class(**scheduler_config) - scheduler.set_timesteps(kwargs["num_inference_steps"]) + scheduler.set_timesteps(num_inference_steps) # copy over dummy past residuals scheduler.ets = dummy_past_residuals[:] with tempfile.TemporaryDirectory() as tmpdirname: scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) - new_scheduler.set_timesteps(kwargs["num_inference_steps"]) + new_scheduler.set_timesteps(num_inference_steps) # copy over dummy past residuals new_scheduler.ets = dummy_past_residuals[:] @@ -416,7 +416,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): def check_over_forward(self, time_step=0, **forward_kwargs): kwargs = dict(self.forward_default_kwargs) - kwargs.update(forward_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) sample = self.dummy_sample residual = 0.1 * sample dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] @@ -424,7 +424,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): for scheduler_class in self.scheduler_classes: scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - scheduler.set_timesteps(kwargs["num_inference_steps"]) + scheduler.set_timesteps(num_inference_steps) # copy over dummy past residuals scheduler.ets = dummy_past_residuals[:] @@ -434,7 +434,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): 
new_scheduler = scheduler_class.from_config(tmpdirname) # copy over dummy past residuals new_scheduler.ets = dummy_past_residuals[:] - new_scheduler.set_timesteps(kwargs["num_inference_steps"]) + new_scheduler.set_timesteps(num_inference_steps) output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] @@ -474,12 +474,12 @@ class PNDMSchedulerTest(SchedulerCommonTest): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"] - output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"] + output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"] + output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"] assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" - output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"] - output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"] + output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"] + output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"] assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" @@ -503,14 +503,14 @@ class PNDMSchedulerTest(SchedulerCommonTest): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"] - output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"] + output_0 = scheduler.step_prk(residual, 0, 
sample, **kwargs)["prev_sample"] + output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"] self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) - output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"] - output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"] + output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"] + output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"] self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) @@ -541,7 +541,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"] + scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"] def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] @@ -555,11 +555,11 @@ class PNDMSchedulerTest(SchedulerCommonTest): for i, t in enumerate(scheduler.prk_timesteps): residual = model(sample, t) - sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"] + sample = scheduler.step_prk(residual, i, sample)["prev_sample"] for i, t in enumerate(scheduler.plms_timesteps): residual = model(sample, t) - sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"] + sample = scheduler.step_plms(residual, i, sample)["prev_sample"] result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) @@ -706,7 +706,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): model_output = model(sample, sigma_t) output = scheduler.step_pred(model_output, t, sample, **kwargs) - sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] + sample, _ = output["prev_sample"], output["prev_sample_mean"] result_sum = 
torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))