From bff9746da009a36de7717cd2b05cc5117356b99a Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 13 Jun 2022 14:33:48 +0200 Subject: [PATCH] GLIDE + DDIM without artifacts --- src/diffusers/pipelines/pipeline_glide.py | 21 ++++--------------- src/diffusers/schedulers/scheduling_ddim.py | 23 +++++++++++++-------- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index 6d2c3982fd..30f2ac2bb5 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -859,9 +859,6 @@ class GLIDE(DiffusionPipeline): nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0 image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise - image = image[:1].permute(0, 2, 3, 1) - return image - # 4. Run the upscaling step batch_size = 1 image = image[:1] @@ -879,20 +876,10 @@ class GLIDE(DiffusionPipeline): ) image = image.to(torch_device) * upsample_temp - # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf - # Ideally, read DDIM paper in-detail understanding - - # Notation ( -> - # - pred_noise_t -> e_theta(x_t, t) - # - pred_original_image -> f_theta(x_t, t) or x_0 - # - std_dev_t -> sigma_t - # - eta -> η - # - pred_image_direction -> "direction pointingc to x_t" - # - pred_prev_image -> "x_t-1" - num_trained_timesteps = self.upscale_noise_scheduler.timesteps inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale) - self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale) + # adapt the beta schedule to the number of steps + # self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale) for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale): # 1. predict noise residual @@ -903,7 +890,7 @@ class GLIDE(DiffusionPipeline): # 2. predict previous mean of image x_t-1 pred_prev_image = self.upscale_noise_scheduler.step( - noise_residual, image, t, num_inference_steps_upscale, eta + noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True ) # 3. optionally sample variance @@ -917,6 +904,6 @@ class GLIDE(DiffusionPipeline): # 4. set current image to prev_image: x_t -> x_t-1 image = pred_prev_image + variance - image = image.permute(0, 2, 3, 1) + image = image.clamp(-1, 1).permute(0, 2, 3, 1) return image diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 883a358d34..4311db0461 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -69,14 +69,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): # # self.register_buffer("log_variance", log_variance.to(torch.float32)) - def rescale_betas(self, num_timesteps): - if self.beta_schedule == "linear": - scale = self.timesteps / num_timesteps - self.betas = linear_beta_schedule( - num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale - ) - self.alphas = 1.0 - self.betas - self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + # def rescale_betas(self, num_timesteps): + # # GLIDE scaling + # if self.beta_schedule == "linear": + # scale = self.timesteps / num_timesteps + # self.betas = linear_beta_schedule( + # num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale + # ) + # self.alphas = 1.0 - self.betas + # self.alphas_cumprod = np.cumprod(self.alphas, axis=0) def get_alpha(self, time_step): return self.alphas[time_step] @@ -107,7 +108,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): return variance - def step(self, residual, image, t, num_inference_steps, eta): + def step(self, residual, image, t, num_inference_steps, eta, use_clipped_residual=False): # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf # Ideally, read DDIM paper in-detail understanding @@ -141,6 +142,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): variance = self.get_variance(t, num_inference_steps) std_dev_t = eta * variance ** (0.5) + if use_clipped_residual: + # the residual is always re-derived from the clipped x_0 in GLIDE + residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5) + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual