Mirror of https://github.com/huggingface/diffusers.git

correct more

Patrick von Platen
2022-06-10 12:49:40 +00:00
parent a14d774b40
commit 01cf739213
2 changed files with 11 additions and 20 deletions


@@ -30,9 +30,6 @@ class DDIM(DiffusionPipeline):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
-        num_trained_timesteps = self.noise_scheduler.num_timesteps
-        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
         self.unet.to(torch_device)
         # Sample gaussian noise to begin loop
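
For context, the inference_step_times range removed here is just the evenly spaced training timesteps, so the lookup it used to provide can be recomputed per step; that is what the new get_orig_t helper below does. A small self-contained check, assuming 1000 training timesteps and 50 inference steps (illustrative values, not taken from the diff):

# The old precomputed schedule and the new on-the-fly mapping agree.
num_trained_timesteps = 1000   # assumed value, for illustration only
num_inference_steps = 50
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)

for t in range(num_inference_steps):
    orig_t = num_trained_timesteps // num_inference_steps * t  # what get_orig_t returns for t >= 0
    assert inference_step_times[t] == orig_t                   # e.g. t=49 -> 980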
@@ -42,20 +39,11 @@ class DDIM(DiffusionPipeline):
             generator=generator,
         )
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
         for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
             # 1. predict noise residual
+            orig_t = self.noise_scheduler.get_orig_t(t, num_inference_steps)
             with torch.no_grad():
-                residual = self.unet(image, inference_step_times[t])
+                residual = self.unet(image, orig_t)
             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.noise_scheduler.compute_prev_image_step(residual, image, t, num_inference_steps, eta)
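
With the schedule gone from the pipeline, the loop now asks the scheduler for the matching training timestep on every iteration; the hunks below add that helper to DDIMScheduler. A minimal skeleton of the refactored loop, with stand-in unet and noise_scheduler arguments and an assumed noise shape (nothing beyond what the diff shows is taken from the real pipeline):

import torch
import tqdm

def sample(unet, noise_scheduler, num_inference_steps=50, eta=0.0, generator=None, torch_device="cpu"):
    # Sample gaussian noise to begin the loop (shape is illustrative)
    image = torch.randn((1, 3, 32, 32), generator=generator).to(torch_device)
    for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
        # 1. map the inference step index to the original training timestep
        orig_t = noise_scheduler.get_orig_t(t, num_inference_steps)
        # 2. predict the noise residual at that timestep
        with torch.no_grad():
            residual = unet(image, orig_t)
        # 3. step the scheduler back to x_{t-1}
        image = noise_scheduler.compute_prev_image_step(residual, image, t, num_inference_steps, eta)
    return image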


@@ -87,9 +87,14 @@ class DDIMScheduler(nn.Module, ConfigMixin):
             return torch.tensor(1.0)
         return self.alphas_cumprod[time_step]
+    def get_orig_t(self, t, num_inference_steps):
+        if t < 0:
+            return -1
+        return self.num_timesteps // num_inference_steps * t
     def get_variance(self, t, num_inference_steps):
-        orig_t = (self.num_timesteps // num_inference_steps) * t
-        orig_prev_t = (self.num_timesteps // num_inference_steps) * (t - 1) if t > 0 else -1
+        orig_t = self.get_orig_t(t, num_inference_steps)
+        orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
         alpha_prod_t = self.get_alpha_prod(orig_t)
         alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
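
The remainder of get_variance falls outside this hunk; for reference, the quantity it computes from the two cumulative alpha products fetched above is the DDIM variance of formula (16), sketched here under the assumption that both inputs are scalars or 0-d tensors:

def ddim_variance(alpha_prod_t, alpha_prod_t_prev):
    # sigma_t^2 = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * (1 - alpha_bar_t / alpha_bar_{t-1})
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    return (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)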
@@ -113,10 +118,8 @@ class DDIMScheduler(nn.Module, ConfigMixin):
         # - pred_prev_image -> "x_t-1"
         # 1. get actual t and t-1
-        orig_t = (self.num_timesteps // num_inference_steps) * t
-        orig_prev_t = (self.num_timesteps // num_inference_steps) * (t - 1) if t > 0 else -1
-        # train_step = inference_step_times[t]
-        # prev_train_step = inference_step_times[t - 1] if t > 0 else -1
+        orig_t = self.get_orig_t(t, num_inference_steps)
+        orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)
         # 2. compute alphas, betas
         alpha_prod_t = self.get_alpha_prod(orig_t)
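
This final hunk only swaps in get_orig_t; the DDIM update of formula (12) that compute_prev_image_step then performs lies outside the diff. A self-contained sketch of that update using the variable names from the notation comment above; the library's actual implementation may differ in details:

import torch

def ddim_step(image, pred_noise_t, alpha_prod_t, alpha_prod_t_prev, eta, generator=None):
    # f_theta(x_t, t): predicted x_0, recovered from the noise residual
    pred_original_image = (image - (1 - alpha_prod_t) ** 0.5 * pred_noise_t) / alpha_prod_t ** 0.5
    # sigma_t from formula (16), scaled by eta (eta = 0 gives deterministic DDIM)
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    std_dev_t = eta * variance ** 0.5
    # "direction pointing to x_t"
    pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * pred_noise_t
    # x_{t-1} per formula (12)
    pred_prev_image = alpha_prod_t_prev ** 0.5 * pred_original_image + pred_image_direction
    if eta > 0:
        noise = torch.randn(image.shape, generator=generator)
        pred_prev_image = pred_prev_image + std_dev_t * noise
    return pred_prev_image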