mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
GLIDE + DDIM without artifacts
This commit is contained in:
@@ -859,9 +859,6 @@ class GLIDE(DiffusionPipeline):
|
||||
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
|
||||
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
|
||||
|
||||
image = image[:1].permute(0, 2, 3, 1)
|
||||
return image
|
||||
|
||||
# 4. Run the upscaling step
|
||||
batch_size = 1
|
||||
image = image[:1]
|
||||
@@ -879,20 +876,10 @@ class GLIDE(DiffusionPipeline):
|
||||
)
|
||||
image = image.to(torch_device) * upsample_temp
|
||||
|
||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Ideally, read DDIM paper in-detail understanding
|
||||
|
||||
# Notation (<variable name> -> <name in paper>
|
||||
# - pred_noise_t -> e_theta(x_t, t)
|
||||
# - pred_original_image -> f_theta(x_t, t) or x_0
|
||||
# - std_dev_t -> sigma_t
|
||||
# - eta -> η
|
||||
# - pred_image_direction -> "direction pointingc to x_t"
|
||||
# - pred_prev_image -> "x_t-1"
|
||||
|
||||
num_trained_timesteps = self.upscale_noise_scheduler.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
|
||||
self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)
|
||||
# adapt the beta schedule to the number of steps
|
||||
# self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)
|
||||
|
||||
for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
|
||||
# 1. predict noise residual
|
||||
@@ -903,7 +890,7 @@ class GLIDE(DiffusionPipeline):
|
||||
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = self.upscale_noise_scheduler.step(
|
||||
noise_residual, image, t, num_inference_steps_upscale, eta
|
||||
noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
|
||||
)
|
||||
|
||||
# 3. optionally sample variance
|
||||
@@ -917,6 +904,6 @@ class GLIDE(DiffusionPipeline):
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
image = image.clamp(-1, 1).permute(0, 2, 3, 1)
|
||||
|
||||
return image
|
||||
|
||||
@@ -69,14 +69,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
#
|
||||
# self.register_buffer("log_variance", log_variance.to(torch.float32))
|
||||
|
||||
def rescale_betas(self, num_timesteps):
|
||||
if self.beta_schedule == "linear":
|
||||
scale = self.timesteps / num_timesteps
|
||||
self.betas = linear_beta_schedule(
|
||||
num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
|
||||
)
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
|
||||
# def rescale_betas(self, num_timesteps):
|
||||
# # GLIDE scaling
|
||||
# if self.beta_schedule == "linear":
|
||||
# scale = self.timesteps / num_timesteps
|
||||
# self.betas = linear_beta_schedule(
|
||||
# num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
|
||||
# )
|
||||
# self.alphas = 1.0 - self.betas
|
||||
# self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
|
||||
|
||||
def get_alpha(self, time_step):
|
||||
return self.alphas[time_step]
|
||||
@@ -107,7 +108,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def step(self, residual, image, t, num_inference_steps, eta):
|
||||
def step(self, residual, image, t, num_inference_steps, eta, use_clipped_residual=False):
|
||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Ideally, read DDIM paper in-detail understanding
|
||||
|
||||
@@ -141,6 +142,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
variance = self.get_variance(t, num_inference_steps)
|
||||
std_dev_t = eta * variance ** (0.5)
|
||||
|
||||
if use_clipped_residual:
|
||||
# the residual is always re-derived from the clipped x_0 in GLIDE
|
||||
residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5)
|
||||
|
||||
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
|
||||
|
||||
|
||||
Reference in New Issue
Block a user