From f035fbfba7a3d38fbb3f6d7cd68ceb4a9b11307d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 9 Jun 2022 16:30:56 +0200 Subject: [PATCH] improve ddim comments --- models/vision/ddpm/modeling_ddpm.py | 46 ++++++++++++++--------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index 584a61454c..f041235fde 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -30,43 +30,43 @@ class DDPM(DiffusionPipeline): torch_device = "cuda" if torch.cuda.is_available() else "cpu" self.unet.to(torch_device) - # 1. Sample gaussian noise + + # Sample gaussian noise to begin loop image = self.noise_scheduler.sample_noise( (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator, ) - for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)): - # i) define coefficients for time step t - clipped_image_coeff = 1 / torch.sqrt(self.noise_scheduler.get_alpha_prod(t)) - clipped_noise_coeff = torch.sqrt(1 / self.noise_scheduler.get_alpha_prod(t) - 1) - image_coeff = ( - (1 - self.noise_scheduler.get_alpha_prod(t - 1)) - * torch.sqrt(self.noise_scheduler.get_alpha(t)) - / (1 - self.noise_scheduler.get_alpha_prod(t)) - ) - clipped_coeff = ( - torch.sqrt(self.noise_scheduler.get_alpha_prod(t - 1)) - * self.noise_scheduler.get_beta(t) - / (1 - self.noise_scheduler.get_alpha_prod(t)) - ) - # ii) predict noise residual + for t in tqdm.tqdm(reversed(range(len(self.noise_scheduler))), total=len(self.noise_scheduler)): + # 1. predict noise residual with torch.no_grad(): noise_residual = self.unet(image, t) - # iii) compute predicted image from residual - # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison - pred_mean = clipped_image_coeff * image - clipped_noise_coeff * noise_residual - pred_mean = torch.clamp(pred_mean, -1, 1) - prev_image = clipped_coeff * pred_mean + image_coeff * image + # 2. compute alphas, betas + alpha_prod_t = self.noise_scheduler.get_alpha_prod(t) + alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(t - 1) + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev - # iv) sample variance + # 3. compute predicted image from residual + # See 2nd formula at https://github.com/hojonathanho/diffusion/issues/5#issue-896554416 for comparison + # First: Compute inner formula + pred_mean = (1 / alpha_prod_t.sqrt()) * (image - beta_prod_t.sqrt() * noise_residual) + # Second: Clip + pred_mean = torch.clamp(pred_mean, -1, 1) + # Third: Compute outer coefficients + pred_mean_coeff = (alpha_prod_t_prev.sqrt() * self.noise_scheduler.get_beta(t)) / beta_prod_t + image_coeff = (beta_prod_t_prev * self.noise_scheduler.get_alpha(t).sqrt()) / beta_prod_t + # Fourth: Compute outer formula + prev_image = pred_mean_coeff * pred_mean + image_coeff * image + + # 4. sample variance prev_variance = self.noise_scheduler.sample_variance( t, prev_image.shape, device=torch_device, generator=generator ) - # v) sample x_{t-1} ~ N(prev_image, prev_variance) + # 5. sample x_{t-1} ~ N(prev_image, prev_variance) = add variance to predicted image sampled_prev_image = prev_image + prev_variance image = sampled_prev_image