From 559b8cbf469f70e12bffb1607bc189ef5e14651c Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 14 Jun 2022 10:17:45 +0000
Subject: [PATCH 1/2] finish pndm

---
 src/diffusers/pipelines/pipeline_pndm.py    | 58 +++-----------
 src/diffusers/schedulers/scheduling_pndm.py | 86 +++++++++++++--------
 2 files changed, 64 insertions(+), 80 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_pndm.py b/src/diffusers/pipelines/pipeline_pndm.py
index d0fe01ec8a..1116b6042a 100644
--- a/src/diffusers/pipelines/pipeline_pndm.py
+++ b/src/diffusers/pipelines/pipeline_pndm.py
@@ -32,9 +32,6 @@ class PNDM(DiffusionPipeline):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        num_trained_timesteps = self.noise_scheduler.timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
         self.unet.to(torch_device)
 
         # Sample gaussian noise to begin loop
@@ -44,55 +41,22 @@ class PNDM(DiffusionPipeline):
         )
         image = image.to(torch_device)
 
-        seq = list(inference_step_times)
-        seq_next = [-1] + list(seq[:-1])
-        model = self.unet
-
-        warmup_time_steps = list(reversed([(t + 5) // 10 * 10 for t in range(seq[-4], seq[-1], 5)]))
-
-        cur_residual = 0
+        warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
         prev_image = image
-        ets = []
-        for i in range(len(warmup_time_steps)):
-            t = warmup_time_steps[i] * torch.ones(image.shape[0])
-            t_next = (warmup_time_steps[i + 1] if i < len(warmup_time_steps) - 1 else warmup_time_steps[-1]) * torch.ones(image.shape[0])
+        for t in tqdm.tqdm(range(len(warmup_time_steps))):
+            t_orig = warmup_time_steps[t]
+            residual = self.unet(image, t_orig)
 
-            residual = model(image.to("cuda"), t.to("cuda"))
-            residual = residual.to("cpu")
-
-            if i % 4 == 0:
-                cur_residual += 1 / 6 * residual
-                ets.append(residual)
+            if t % 4 == 0:
                 prev_image = image
-            elif (i - 1) % 4 == 0:
-                cur_residual += 1 / 3 * residual
-            elif (i - 2) % 4 == 0:
-                cur_residual += 1 / 3 * residual
-            elif (i - 3) % 4 == 0:
-                cur_residual += 1 / 6 * residual
-                residual = cur_residual
-                cur_residual = 0
 
-            image = image.to("cpu")
-            t_2 = warmup_time_steps[4 * (i // 4)] * torch.ones(image.shape[0])
-            image = self.noise_scheduler.transfer(prev_image.to("cpu"), t_2, t_next, residual)
+            image = self.noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps)
 
-        step_idx = len(seq) - 4
-        while step_idx >= 0:
-            i = seq[step_idx]
-            j = seq_next[step_idx]
+        timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
+        for t in tqdm.tqdm(range(len(timesteps))):
+            t_orig = timesteps[t]
+            residual = self.unet(image, t_orig)
 
-            t = (torch.ones(image.shape[0]) * i)
-            t_next = (torch.ones(image.shape[0]) * j)
-
-            residual = model(image.to("cuda"), t.to("cuda"))
-            residual = residual.to("cpu")
-            ets.append(residual)
-            residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
-
-            img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
-            image = img_next
-
-            step_idx = step_idx - 1
+            image = self.noise_scheduler.step(residual, image, t, num_inference_steps)
 
         return image
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index d69e4a5188..ad639a90f1 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -55,6 +55,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         self.set_format(tensor_format=tensor_format)
 
+        # hardcode for now
+        self.pndm_order = 4
+        self.cur_residual = 0
+
+        # running values
+        self.ets = []
+        self.warmup_time_steps = {}
+        self.time_steps = {}
+
         # self.register_buffer("betas", betas.to(torch.float32))
         # self.register_buffer("alphas", alphas.to(torch.float32))
         # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
@@ -83,51 +92,62 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             return self.one
         return self.alphas_cumprod[time_step]
 
-    def step(self, img, t_start, t_end, model, ets):
-#        img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets)
-#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
-        t_next, t = t_start, t_end
+    def get_warmup_time_steps(self, num_inference_steps):
+        if num_inference_steps in self.warmup_time_steps:
+            return self.warmup_time_steps[num_inference_steps]
 
-        noise_ = model(img.to("cuda"), t.to("cuda"))
-        noise_ = noise_.to("cpu")
+        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
 
-        t_list = [t, (t+t_next)/2, t_next]
-        if len(ets) > 2:
-            ets.append(noise_)
-            noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
-        else:
-            noise = self.runge_kutta(img, t_list, model, ets, noise_)
+        warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order)
+        self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
 
-        img_next = self.transfer(img.to("cpu"), t, t_next, noise)
-        return img_next, ets
+        return self.warmup_time_steps[num_inference_steps]
 
-    def runge_kutta(self, x, t_list, model, ets, noise_):
-        model = model.to("cuda")
-        x = x.to("cpu")
+    def get_time_steps(self, num_inference_steps):
+        if num_inference_steps in self.time_steps:
+            return self.time_steps[num_inference_steps]
 
-        e_1 = noise_
-        ets.append(e_1)
-        x_2 = self.transfer(x, t_list[0], t_list[1], e_1)
+        inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
+        self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
 
-        e_2 = model(x_2.to("cuda"), t_list[1].to("cuda"))
-        e_2 = e_2.to("cpu")
-        x_3 = self.transfer(x, t_list[0], t_list[1], e_2)
+        return self.time_steps[num_inference_steps]
 
-        e_3 = model(x_3.to("cuda"), t_list[1].to("cuda"))
-        e_3 = e_3.to("cpu")
-        x_4 = self.transfer(x, t_list[0], t_list[2], e_3)
+    def step_warm_up(self, residual, image, t, num_inference_steps):
+        warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
 
-        e_4 = model(x_4.to("cuda"), t_list[2].to("cuda"))
-        e_4 = e_4.to("cpu")
+        t_prev = warmup_time_steps[t // 4 * 4]
+        t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
 
-        et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4)
+        if t % 4 == 0:
+            self.cur_residual += 1 / 6 * residual
+            self.ets.append(residual)
+        elif (t - 1) % 4 == 0:
+            self.cur_residual += 1 / 3 * residual
+        elif (t - 2) % 4 == 0:
+            self.cur_residual += 1 / 3 * residual
+        elif (t - 3) % 4 == 0:
+            residual = self.cur_residual + 1 / 6 * residual
+            self.cur_residual = 0
 
-        return et
+        return self.transfer(image, t_prev, t_next, residual)
+
+    def step(self, residual, image, t, num_inference_steps):
+        timesteps = self.get_time_steps(num_inference_steps)
+
+        t_prev = timesteps[t]
+        t_next = timesteps[min(t + 1, len(timesteps) - 1)]
+        self.ets.append(residual)
+
+        residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+        return self.transfer(image, t_prev, t_next, residual)
 
     def transfer(self, x, t, t_next, et):
-        alphas_cump = self.alphas_cumprod
-        at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
-        at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)
+        # TODO(Patrick): clean up to be compatible with numpy and give better names
+
+        alphas_cump = self.alphas_cumprod.to(x.device)
+        at = alphas_cump[t + 1].view(-1, 1, 1, 1)
+        at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
 
         x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et)
 

From f0a99e76846537525cf509d7e2b16b3b91f46de7 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 14 Jun 2022 10:22:53 +0000
Subject: [PATCH 2/2] finish

---
 src/diffusers/schedulers/scheduling_pndm.py | 21 ++-------------------
 1 file changed, 2 insertions(+), 19 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index ad639a90f1..fa1c9ca56d 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -55,32 +55,15 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         self.set_format(tensor_format=tensor_format)
 
-        # hardcode for now
+        # for now we only support F-PNDM, i.e. the runge-kutta method
         self.pndm_order = 4
-        self.cur_residual = 0
 
         # running values
+        self.cur_residual = 0
         self.ets = []
         self.warmup_time_steps = {}
         self.time_steps = {}
 
-        # self.register_buffer("betas", betas.to(torch.float32))
-        # self.register_buffer("alphas", alphas.to(torch.float32))
-        # self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
-
-        # alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
-        # TODO(PVP) - check how much of these is actually necessary!
-        # LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
-        # https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
-        # variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
-        # if variance_type == "fixed_small":
-        #     log_variance = torch.log(variance.clamp(min=1e-20))
-        # elif variance_type == "fixed_large":
-        #     log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
-        #
-        #
-        # self.register_buffer("log_variance", log_variance.to(torch.float32))
-
     def get_alpha(self, time_step):
         return self.alphas[time_step]
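
Usage note (outside the patches above): the sketch below shows how the refactored scheduler API from PATCH 1/2 is driven from a sampling loop, mirroring the new pipeline code. `pndm_sample` is a hypothetical helper; `unet`, `noise_scheduler`, and the initial noise `image` are assumed to be a trained model, a configured PNDMScheduler, and a Gaussian-noise tensor. Only the method names and call signatures come from the patch; everything else is illustrative.

    import tqdm

    def pndm_sample(unet, noise_scheduler, image, num_inference_steps=50):
        # warmup phase: Runge-Kutta style updates that fill the scheduler's `ets` buffer
        warmup_time_steps = noise_scheduler.get_warmup_time_steps(num_inference_steps)
        prev_image = image
        for t in tqdm.tqdm(range(len(warmup_time_steps))):
            residual = unet(image, warmup_time_steps[t])
            if t % 4 == 0:
                prev_image = image  # each group of four residuals forms one Runge-Kutta step
            image = noise_scheduler.step_warm_up(residual, prev_image, t, num_inference_steps)

        # main phase: fourth-order linear multistep (the 55/-59/37/-9 coefficients in `step`)
        # over the four most recent residuals cached in `ets`
        timesteps = noise_scheduler.get_time_steps(num_inference_steps)
        for t in tqdm.tqdm(range(len(timesteps))):
            residual = unet(image, timesteps[t])
            image = noise_scheduler.step(residual, image, t, num_inference_steps)

        return image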