Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

correct more
@@ -30,9 +30,6 @@ class DDIM(DiffusionPipeline):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

-       num_trained_timesteps = self.noise_scheduler.num_timesteps
-       inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
-
        self.unet.to(torch_device)

        # Sample gaussian noise to begin loop
@@ -42,20 +39,11 @@ class DDIM(DiffusionPipeline):
            generator=generator,
        )

        # See formulas (12) and (16) of the DDIM paper: https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper in detail to follow the steps below

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_image -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_image_direction -> "direction pointing to x_t"
        # - pred_prev_image -> "x_t-1"
        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
            # 1. predict noise residual
+           orig_t = self.noise_scheduler.get_orig_t(t, num_inference_steps)
            with torch.no_grad():
-               residual = self.unet(image, inference_step_times[t])
+               residual = self.unet(image, orig_t)

            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.noise_scheduler.compute_prev_image_step(residual, image, t, num_inference_steps, eta)
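For reference, these are formulas (12) and (16) of the DDIM paper that the notation block above maps onto; in the paper's notation, α_t denotes the cumulative product of the per-step alphas, i.e. the code's alpha_prod_t / alpha_prod_t_prev:

    % DDIM Eq. (12): one reverse step from x_t to x_{t-1}
    x_{t-1} = \sqrt{\alpha_{t-1}}
              \underbrace{\left( \frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}} \right)}_{\text{pred\_original\_image}}
            + \underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t, t)}_{\text{pred\_image\_direction}}
            + \sigma_t\,\epsilon_t

    % DDIM Eq. (16): the noise scale controlled by eta (eta = 0 gives the deterministic sampler)
    \sigma_t(\eta) = \eta \, \sqrt{\frac{1-\alpha_{t-1}}{1-\alpha_t}} \, \sqrt{1-\frac{\alpha_t}{\alpha_{t-1}}}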
@@ -87,9 +87,14 @@ class DDIMScheduler(nn.Module, ConfigMixin):
            return torch.tensor(1.0)
        return self.alphas_cumprod[time_step]

+   def get_orig_t(self, t, num_inference_steps):
+       if t < 0:
+           return -1
+       return self.num_timesteps // num_inference_steps * t
+
    def get_variance(self, t, num_inference_steps):
-       orig_t = (self.num_timesteps // num_inference_steps) * t
-       orig_prev_t = (self.num_timesteps // num_inference_steps) * (t - 1) if t > 0 else -1
+       orig_t = self.get_orig_t(t, num_inference_steps)
+       orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)

        alpha_prod_t = self.get_alpha_prod(orig_t)
        alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
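The new get_orig_t helper simply rescales an inference step index onto the original training timestep grid. A standalone sketch with a worked example (assuming 1000 training timesteps, purely for illustration):

    # Standalone sketch of the get_orig_t mapping introduced above (illustration only).
    def get_orig_t(t, num_inference_steps, num_timesteps=1000):
        if t < 0:
            return -1  # sentinel for "the step before the first one"
        return num_timesteps // num_inference_steps * t

    # With 50 inference steps, the loop indices 0..49 land on training
    # timesteps 0, 20, 40, ..., 980:
    print(get_orig_t(0, 50))    # 0
    print(get_orig_t(10, 50))   # 200
    print(get_orig_t(49, 50))   # 980
    print(get_orig_t(-1, 50))   # -1, which get_alpha_prod treats as "no previous step"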
@@ -113,10 +118,8 @@ class DDIMScheduler(nn.Module, ConfigMixin):
        # - pred_prev_image -> "x_t-1"

        # 1. get actual t and t-1
-       orig_t = (self.num_timesteps // num_inference_steps) * t
-       orig_prev_t = (self.num_timesteps // num_inference_steps) * (t - 1) if t > 0 else -1
-       # train_step = inference_step_times[t]
-       # prev_train_step = inference_step_times[t - 1] if t > 0 else -1
+       orig_t = self.get_orig_t(t, num_inference_steps)
+       orig_prev_t = self.get_orig_t(t - 1, num_inference_steps)

        # 2. compute alphas, betas
        alpha_prod_t = self.get_alpha_prod(orig_t)
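To make the numbered steps concrete, here is a minimal, self-contained sketch of the DDIM update those comments describe (essentially Eqs. (12) and (16) above), assuming residual, image, and the two alpha products are plain tensors or floats. This is an illustrative sketch, not the repository's compute_prev_image_step:

    import torch

    # Illustrative sketch of the DDIM update walked through by the comments above;
    # names mirror the notation block (pred_original_image, std_dev_t, ...).
    def ddim_step_sketch(residual, image, alpha_prod_t, alpha_prod_t_prev, eta, generator=None):
        # predict x_0 from the noise residual e_theta(x_t, t)
        pred_original_image = (image - (1 - alpha_prod_t) ** 0.5 * residual) / alpha_prod_t ** 0.5

        # sigma_t from eta (Eq. (16)); eta = 0 makes the step deterministic
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
        std_dev_t = eta * variance ** 0.5

        # "direction pointing to x_t" (Eq. (12))
        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * residual

        # assemble x_{t-1}
        pred_prev_image = alpha_prod_t_prev ** 0.5 * pred_original_image + pred_image_direction
        if eta > 0:
            noise = torch.randn(image.shape, generator=generator)
            pred_prev_image = pred_prev_image + std_dev_t * noise
        return pred_prev_image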