From 01cf7392130a7c6bebf198e7a894e6ef828f01ff Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Fri, 10 Jun 2022 12:49:40 +0000
Subject: [PATCH] correct more

---
 models/vision/ddim/modeling_ddim.py | 16 ++--------------
 src/diffusers/schedulers/ddim.py    | 15 +++++++++------
 2 files changed, 11 insertions(+), 20 deletions(-)

diff --git a/models/vision/ddim/modeling_ddim.py b/models/vision/ddim/modeling_ddim.py
index 7cf564e2ef..c11b8e4d1d 100644
--- a/models/vision/ddim/modeling_ddim.py
+++ b/models/vision/ddim/modeling_ddim.py
@@ -30,9 +30,6 @@ class DDIM(DiffusionPipeline):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
-        num_trained_timesteps = self.noise_scheduler.num_timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
         self.unet.to(torch_device)
 
         # Sample gaussian noise to begin loop
@@ -42,20 +39,11 @@ class DDIM(DiffusionPipeline):
             generator=generator,
         )
 
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-
-        # Notation (<variable name> -> <name in paper>)
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointing to x_t"
-        # - pred_prev_image -> "x_t-1"
         for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
             # 1. predict noise residual
+            orig_t = self.noise_scheduler.get_orig_t(t, num_inference_steps)
             with torch.no_grad():
-                residual = self.unet(image, inference_step_times[t])
+                residual = self.unet(image, orig_t)
 
             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.noise_scheduler.compute_prev_image_step(residual, image, t, num_inference_steps, eta)
diff --git a/src/diffusers/schedulers/ddim.py b/src/diffusers/schedulers/ddim.py
index f9085f72a1..a0f1099e8a 100644
--- a/src/diffusers/schedulers/ddim.py
+++ b/src/diffusers/schedulers/ddim.py
@@ -87,9 +87,14 @@ class DDIMScheduler(nn.Module, ConfigMixin):
             return torch.tensor(1.0)
         return self.alphas_cumprod[time_step]
 
+    def get_orig_t(self, t, num_inference_steps):
+        if t < 0:
+            return -1
+        return self.num_timesteps // num_inference_steps * t
+
     def get_variance(self, t, num_inference_steps):
-        orig_t = (self.num_timesteps // num_inference_steps) * t
-        orig_prev_t = (self.num_timesteps // num_inference_steps) * (t - 1) if t > 0 else -1
+        orig_t = self.get_orig_t(t, num_inference_steps)
+        orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
 
         alpha_prod_t = self.get_alpha_prod(orig_t)
         alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
@@ -113,10 +118,8 @@ class DDIMScheduler(nn.Module, ConfigMixin):
         # - pred_prev_image -> "x_t-1"
 
         # 1. get actual t and t-1
-        orig_t = (self.num_timesteps // num_inference_steps) * t
-        orig_prev_t = (self.num_timesteps // num_inference_steps) * (t - 1) if t > 0 else -1
-#        train_step = inference_step_times[t]
-#        prev_train_step = inference_step_times[t - 1] if t > 0 else -1
+        orig_t = self.get_orig_t(t, num_inference_steps)
+        orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
 
         # 2. compute alphas, betas
         alpha_prod_t = self.get_alpha_prod(orig_t)
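
Note: for reference, a minimal standalone sketch of the timestep mapping this patch factors out into get_orig_t. Each of the num_inference_steps DDIM steps is mapped onto an evenly spaced training timestep, and t = -1 is the sentinel for "the step before t = 0", which the patched get_alpha_prod resolves to alpha_prod = 1.0. The standalone function signature and the illustrative values below (1000 training steps, 50 inference steps) are assumptions for demonstration, not part of the patch:

    # Standalone sketch (assumed values; mirrors get_orig_t from the patch).

    def get_orig_t(t, num_timesteps, num_inference_steps):
        # t < 0 is the "previous step of t = 0" sentinel; the patched
        # get_alpha_prod treats a negative timestep as alpha_prod = 1.0.
        if t < 0:
            return -1
        # Evenly spaced training timesteps: 0, step, 2 * step, ...
        return num_timesteps // num_inference_steps * t

    num_timesteps = 1000      # assumed number of training timesteps
    num_inference_steps = 50  # assumed number of DDIM sampling steps

    # Sampling loop runs t = 49, 48, ..., 0, as in the pipeline above.
    for t in reversed(range(num_inference_steps)):
        orig_t = get_orig_t(t, num_timesteps, num_inference_steps)
        orig_prev_t = get_orig_t(t - 1, num_timesteps, num_inference_steps)
        print(t, orig_t, orig_prev_t)  # t=49 -> (980, 960), ..., t=0 -> (0, -1)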