From b2274ece73bdba5d72b3c82d89af271f5d70a68b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Jun 2022 15:51:03 +0200 Subject: [PATCH] finish pndm scheduler --- src/diffusers/pipelines/pipeline_pndm.py | 6 +- src/diffusers/schedulers/scheduling_pndm.py | 100 ++++++++---- tests/test_scheduler.py | 162 +++++++++++++++++++- 3 files changed, 230 insertions(+), 38 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_pndm.py b/src/diffusers/pipelines/pipeline_pndm.py index 93d735a8a8..a19f933ed1 100644 --- a/src/diffusers/pipelines/pipeline_pndm.py +++ b/src/diffusers/pipelines/pipeline_pndm.py @@ -42,9 +42,9 @@ class PNDM(DiffusionPipeline): ) image = image.to(torch_device) - warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps) - for t in tqdm.tqdm(range(len(warmup_time_steps))): - t_orig = warmup_time_steps[t] + prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps) + for t in tqdm.tqdm(range(len(prk_time_steps))): + t_orig = prk_time_steps[t] residual = self.unet(image, t_orig) image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index d0f860f2a5..686c31140c 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -56,15 +56,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): # For now we only support F-PNDM, i.e. the runge-kutta method # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf - # mainly at equations (12) and (13) and the Algorithm 2. + # mainly at formula (9), (12), (13) and the Algorithm 2. self.pndm_order = 4 # running values self.cur_residual = 0 self.cur_sample = None self.ets = [] - self.warmup_time_steps = {} + self.prk_time_steps = {} self.time_steps = {} + self.set_prk_mode() def get_alpha(self, time_step): return self.alphas[time_step] @@ -77,18 +78,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return self.one return self.alphas_cumprod[time_step] - def get_warmup_time_steps(self, num_inference_steps): - if num_inference_steps in self.warmup_time_steps: - return self.warmup_time_steps[num_inference_steps] + def get_prk_time_steps(self, num_inference_steps): + if num_inference_steps in self.prk_time_steps: + return self.prk_time_steps[num_inference_steps] inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps)) - warmup_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile( + prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile( np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order ) - self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1])) + self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1])) - return self.warmup_time_steps[num_inference_steps] + return self.prk_time_steps[num_inference_steps] def get_time_steps(self, num_inference_steps): if num_inference_steps in self.time_steps: @@ -99,12 +100,25 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): return self.time_steps[num_inference_steps] - def step_prk(self, residual, sample, t, num_inference_steps): - # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here - warmup_time_steps = self.get_warmup_time_steps(num_inference_steps) + def set_prk_mode(self): + self.mode = "prk" - t_prev = warmup_time_steps[t // 4 * 4] - t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)] + def set_plms_mode(self): + self.mode = "plms" + + def step(self, *args, **kwargs): + if self.mode == "prk": + return self.step_prk(*args, **kwargs) + if self.mode == "plms": + return self.step_plms(*args, **kwargs) + + raise ValueError(f"mode {self.mode} does not exist.") + + def step_prk(self, residual, sample, t, num_inference_steps): + prk_time_steps = self.get_prk_time_steps(num_inference_steps) + + t_orig = prk_time_steps[t // 4 * 4] + t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)] if t % 4 == 0: self.cur_residual += 1 / 6 * residual @@ -118,33 +132,63 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): residual = self.cur_residual + 1 / 6 * residual self.cur_residual = 0 - return self.transfer(self.cur_sample, t_prev, t_next, residual) + # cur_sample should not be `None` + cur_sample = self.cur_sample if self.cur_sample is not None else sample + + return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual) def step_plms(self, residual, sample, t, num_inference_steps): + if len(self.ets) < 3: + raise ValueError( + f"{self.__class__} can only be run AFTER scheduler has been run " + "in 'prk' mode for at least 12 iterations " + "See: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_pndm.py " + "for more information." + ) + timesteps = self.get_time_steps(num_inference_steps) - t_prev = timesteps[t] - t_next = timesteps[min(t + 1, len(timesteps) - 1)] + t_orig = timesteps[t] + t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)] self.ets.append(residual) residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4]) - return self.transfer(sample, t_prev, t_next, residual) + return self.get_prev_sample(sample, t_orig, t_orig_prev, residual) - def transfer(self, x, t, t_next, et): - # TODO(Patrick): clean up to be compatible with numpy and give better names + def get_prev_sample(self, sample, t_orig, t_orig_prev, residual): + # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf + # this function computes x_(t−δ) using the formula of (9) + # Note that x_t needs to be added to both sides of the equation - alphas_cump = self.alphas_cumprod.to(x.device) - at = alphas_cump[t + 1].view(-1, 1, 1, 1) - at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1) + # Notation ( -> + # alpha_prod_t -> α_t + # alpha_prod_t_prev -> α_(t−δ) + # beta_prod_t -> (1 - α_t) + # beta_prod_t_prev -> (1 - α_(t−δ)) + # sample -> x_t + # residual -> e_θ(x_t, t) + # prev_sample -> x_(t−δ) + alpha_prod_t = self.get_alpha_prod(t_orig + 1) + alpha_prod_t_prev = self.get_alpha_prod(t_orig_prev + 1) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev - x_delta = (at_next - at) * ( - (1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et - ) + # corresponds to (α_(t−δ) - α_t) divided by + # denominator of x_t in formula (9) and plus 1 + # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = + # sqrt(α_(t−δ)) / sqrt(α_t)) + sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5) - x_next = x + x_delta - return x_next + # corresponds to denominator of e_θ(x_t, t) in formula (9) + residual_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + ( + alpha_prod_t * beta_prod_t * alpha_prod_t_prev + ) ** (0.5) + + # full formula (9) + prev_sample = sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * residual / residual_denom_coeff + + return prev_sample def __len__(self): return self.config.timesteps diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 6a4556d52e..219151f932 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -20,7 +20,7 @@ import unittest import numpy as np import torch -from diffusers import DDIMScheduler, DDPMScheduler +from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler torch.backends.cuda.matmul.allow_tf32 = False @@ -90,10 +90,10 @@ class SchedulerCommonTest(unittest.TestCase): kwargs.update(forward_kwargs) for scheduler_class in self.scheduler_classes: - scheduler_class = self.scheduler_classes[0] image = self.dummy_image residual = 0.1 * image + scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) @@ -159,7 +159,7 @@ class SchedulerCommonTest(unittest.TestCase): output = scheduler.step(residual, image, 1, **kwargs) output_pt = scheduler_pt.step(residual_pt, image_pt, 1, **kwargs) - assert np.sum(np.abs(output - output_pt.numpy())) < 1e-5, "Scheduler outputs are not identical" + assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" class DDPMSchedulerTest(SchedulerCommonTest): @@ -237,8 +237,8 @@ class DDPMSchedulerTest(SchedulerCommonTest): result_sum = np.sum(np.abs(image)) result_mean = np.mean(np.abs(image)) - assert result_sum.item() - 732.9947 < 1e-3 - assert result_mean.item() - 0.9544 < 1e-3 + assert abs(result_sum.item() - 732.9947) < 1e-2 + assert abs(result_mean.item() - 0.9544) < 1e-3 class DDIMSchedulerTest(SchedulerCommonTest): @@ -325,5 +325,153 @@ class DDIMSchedulerTest(SchedulerCommonTest): result_sum = np.sum(np.abs(image)) result_mean = np.mean(np.abs(image)) - assert result_sum.item() - 270.6214 < 1e-3 - assert result_mean.item() - 0.3524 < 1e-3 + assert abs(result_sum.item() - 270.6214) < 1e-2 + assert abs(result_mean.item() - 0.3524) < 1e-3 + + +class PNDMSchedulerTest(SchedulerCommonTest): + scheduler_classes = (PNDMScheduler,) + forward_default_kwargs = (("num_inference_steps", 50),) + + def get_scheduler_config(self, **kwargs): + config = { + "timesteps": 1000, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + } + + config.update(**kwargs) + return config + + def check_over_configs_pmls(self, time_step=0, **config): + kwargs = dict(self.forward_default_kwargs) + image = self.dummy_image + residual = 0.1 * image + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + + for scheduler_class in self.scheduler_classes: + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(**config) + scheduler = scheduler_class(**scheduler_config) + # copy over dummy past residuals + scheduler.ets = dummy_past_residuals[:] + scheduler.set_plms_mode() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_config(tmpdirname) + # copy over dummy past residuals + new_scheduler.ets = dummy_past_residuals[:] + new_scheduler.set_plms_mode() + + output = scheduler.step(residual, image, time_step, **kwargs) + new_output = new_scheduler.step(residual, image, time_step, **kwargs) + + assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def check_over_forward_pmls(self, time_step=0, **forward_kwargs): + kwargs = dict(self.forward_default_kwargs) + kwargs.update(forward_kwargs) + image = self.dummy_image + residual = 0.1 * image + dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] + + for scheduler_class in self.scheduler_classes: + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + # copy over dummy past residuals + scheduler.ets = dummy_past_residuals[:] + scheduler.set_plms_mode() + + with tempfile.TemporaryDirectory() as tmpdirname: + scheduler.save_config(tmpdirname) + new_scheduler = scheduler_class.from_config(tmpdirname) + # copy over dummy past residuals + new_scheduler.ets = dummy_past_residuals[:] + new_scheduler.set_plms_mode() + + output = scheduler.step(residual, image, time_step, **kwargs) + new_output = new_scheduler.step(residual, image, time_step, **kwargs) + + assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + + def test_timesteps(self): + for timesteps in [100, 1000]: + self.check_over_configs(timesteps=timesteps) + + def test_timesteps_pmls(self): + for timesteps in [100, 1000]: + self.check_over_configs_pmls(timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_betas_pmls(self): + for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]): + self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_schedules_pmls(self): + for schedule in ["linear", "squaredcos_cap_v2"]: + self.check_over_configs(beta_schedule=schedule) + + def test_time_indices(self): + for t in [1, 5, 10]: + self.check_over_forward(time_step=t) + + def test_time_indices_pmls(self): + for t in [1, 5, 10]: + self.check_over_forward_pmls(time_step=t) + + def test_inference_steps(self): + for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): + self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + + def test_inference_steps_pmls(self): + for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]): + self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps) + + def test_inference_pmls_no_past_residuals(self): + with self.assertRaises(ValueError): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_plms_mode() + + scheduler.step(self.dummy_image, self.dummy_image, 1, 50) + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + image = self.dummy_image_deter + + prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps) + for t in range(len(prk_time_steps)): + t_orig = prk_time_steps[t] + residual = model(image, t_orig) + + image = scheduler.step_prk(residual, image, t, num_inference_steps) + + timesteps = scheduler.get_time_steps(num_inference_steps) + for t in range(len(timesteps)): + t_orig = timesteps[t] + residual = model(image, t_orig) + + image = scheduler.step_plms(residual, image, t, num_inference_steps) + + result_sum = np.sum(np.abs(image)) + result_mean = np.mean(np.abs(image)) + + assert abs(result_sum.item() - 199.1169) < 1e-2 + assert abs(result_mean.item() - 0.2593) < 1e-3