1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

improve ddim comments

This commit is contained in:
Patrick von Platen
2022-06-09 16:31:02 +02:00
parent f035fbfba7
commit 8841d0d1a9

View File

@@ -34,49 +34,68 @@ class DDIM(DiffusionPipeline):
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device)
# Sample gaussian noise to begin loop
image = self.noise_scheduler.sample_noise(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
device=torch_device,
generator=generator,
)
# See formulas (9), (10) and (7) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# get actual t and t-1
# 1. predict noise residual
with torch.no_grad():
pred_noise_t = self.unet(image, inference_step_times[t])
# 2. get actual t and t-1
train_step = inference_step_times[t]
prev_train_step = inference_step_times[t - 1] if t > 0 else -1
# compute alphas
# 3. compute alphas, betas
alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
beta_prod_t = (1 - alpha_prod_t)
beta_prod_t_prev = (1 - alpha_prod_t_prev)
# compute relevant coefficients
coeff_1 = (
(alpha_prod_t_prev - alpha_prod_t).sqrt()
* alpha_prod_t_prev_rsqrt
* beta_prod_t_prev_sqrt
/ beta_prod_t_sqrt
* eta
)
coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1**2).sqrt()
# 4. Compute predicted previous image from predicted noise
# model forward
with torch.no_grad():
noise_residual = self.unet(image, train_step)
# First: compute predicted original image from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_image = (image - beta_prod_t.sqrt() * pred_noise_t) / alpha_prod_t.sqrt()
# predict mean of prev image
pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
pred_mean = torch.clamp(pred_mean, -1, 1)
pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual
# Second: Clip "predicted x_0"
pred_original_image = torch.clamp(pred_original_image, -1, 1)
# if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
# Third: Compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
std_dev_t = (beta_prod_t_prev / beta_prod_t).sqrt() * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt()
std_dev_t = eta * std_dev_t
# Fourth: Compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
# Fifth: Compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
# 5. Sample x_t-1 image optionally if η > 0.0 by adding noise to pred_prev_image
# Note: eta = 1.0 essentially corresponds to DDPM
if eta > 0.0:
noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
image = pred_mean + coeff_1 * noise
prev_image = pred_prev_image + std_dev_t * noise
else:
image = pred_mean
prev_image = pred_prev_image
# 6. Set current image to prev_image: x_t -> x_t-1
image = prev_image
return image