rename image to sample in schedulers
@@ -28,7 +28,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule="linear",
         trained_betas=None,
         timestep_values=None,
-        clip_predicted_image=True,
+        clip_predicted_sample=True,
         tensor_format="np",
     ):
         super().__init__()
@@ -40,7 +40,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         )
         self.timesteps = int(timesteps)
         self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
-        self.clip_image = clip_predicted_image
+        self.clip_sample = clip_predicted_sample
 
         if beta_schedule == "linear":
             self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
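For orientation, a linear beta schedule of the kind referenced here can be sketched in a few lines of NumPy; the linspace form and the default beta_start/beta_end values below are assumptions for illustration, not the library's actual implementation:

import numpy as np

def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    # assumed form: evenly spaced betas between beta_start and beta_end
    return np.linspace(beta_start, beta_end, timesteps)

betas = linear_beta_schedule(1000)
alphas_cumprod = np.cumprod(1.0 - betas)  # alpha-bar_t, the quantity used by the step formulas below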
@@ -111,17 +111,17 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         return variance
 
-    def step(self, residual, image, t, num_inference_steps, eta, use_clipped_residual=False):
+    def step(self, residual, sample, t, num_inference_steps, eta, use_clipped_residual=False):
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding
 
         # Notation (<variable name> -> <name in paper>
         # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
+        # - pred_original_sample -> f_theta(x_t, t) or x_0
         # - std_dev_t -> sigma_t
         # - eta -> η
-        # - pred_image_direction -> "direction pointingc to x_t"
-        # - pred_prev_image -> "x_t-1"
+        # - pred_sample_direction -> "direction pointingc to x_t"
+        # - pred_prev_sample -> "x_t-1"
 
         # 1. get actual t and t-1
         orig_t = self.get_orig_t(t, num_inference_steps)
@@ -132,13 +132,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         alpha_prod_t_prev = self.get_alpha_prod(orig_prev_t)
         beta_prod_t = 1 - alpha_prod_t
 
-        # 3. compute predicted original image from predicted noise also called
+        # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
 
         # 4. Clip "predicted x_0"
-        if self.clip_image:
-            pred_original_image = self.clip(pred_original_image, -1, 1)
+        if self.clip_sample:
+            pred_original_sample = self.clip(pred_original_sample, -1, 1)
 
         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
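As a minimal, self-contained sketch of steps 3 and 5 above (formulas (12) and (16) of the DDIM paper), with hypothetical alpha-bar values and placeholder tensors rather than real model outputs; the eta factor reflects the η scaling of formula (16), and this is not the scheduler's own code:

import numpy as np

alpha_prod_t, alpha_prod_t_prev = 0.5, 0.7      # hypothetical alpha-bar values for one step
beta_prod_t = 1 - alpha_prod_t
sample = np.zeros(4)                            # x_t, placeholder
residual = np.ones(4)                           # e_theta(x_t, t), placeholder

# step 3, formula (12): predicted x_0
pred_original_sample = (sample - beta_prod_t ** 0.5 * residual) / alpha_prod_t ** 0.5

# step 5, formula (16): sigma_t, scaled by eta
eta = 1.0
std_dev_t = eta * ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (1 - alpha_prod_t / alpha_prod_t_prev) ** 0.5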
@@ -147,15 +147,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         if use_clipped_residual:
             # the residual is always re-derived from the clipped x_0 in GLIDE
-            residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5)
+            residual = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
 
         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
+        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
 
         # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_prev_image = alpha_prod_t_prev ** (0.5) * pred_original_image + pred_image_direction
+        pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
 
-        return pred_prev_image
+        return pred_prev_sample
 
     def __len__(self):
         return self.timesteps
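Continuing that sketch, steps 6 and 7 combine the predicted x_0 and the direction term into the deterministic part of formula (12); again placeholder values, not the scheduler's own code:

import numpy as np

alpha_prod_t_prev = 0.7                  # hypothetical alpha-bar_{t-1}
std_dev_t = 0.0                          # eta = 0 gives the fully deterministic DDIM update
residual = np.ones(4)                    # e_theta(x_t, t), placeholder
pred_original_sample = np.zeros(4)       # predicted x_0 from the earlier sketch

# step 6: "direction pointing to x_t"
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** 0.5 * residual

# step 7: x_{t-1} without the random-noise term of formula (12)
pred_prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction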
@@ -29,7 +29,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas=None,
         timestep_values=None,
         variance_type="fixed_small",
-        clip_predicted_image=True,
+        clip_predicted_sample=True,
         tensor_format="np",
     ):
         super().__init__()
@@ -41,11 +41,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             trained_betas=trained_betas,
             timestep_values=timestep_values,
             variance_type=variance_type,
-            clip_predicted_image=clip_predicted_image,
+            clip_predicted_sample=clip_predicted_sample,
         )
         self.timesteps = int(timesteps)
         self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
-        self.clip_image = clip_predicted_image
+        self.clip_sample = clip_predicted_sample
         self.variance_type = variance_type
 
         if trained_betas is not None:
@@ -100,8 +100,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         alpha_prod_t_prev = self.get_alpha_prod(t - 1)
 
         # For t > 0, compute predicted variance βt (see formala (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
-        # and sample from it to get previous image
-        # x_{t-1} ~ N(pred_prev_image, variance) == add variane to pred_image
+        # and sample from it to get previous sample
+        # x_{t-1} ~ N(pred_prev_sample, variance) == add variane to pred_sample
         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.get_beta(t)
 
         # hacks - were probs added for training stability
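A minimal numeric sketch of this posterior variance, (1 - alpha-bar_{t-1}) / (1 - alpha-bar_t) * beta_t from formula (7), assuming a linear beta schedule purely for illustration:

import numpy as np

betas = np.linspace(0.0001, 0.02, 1000)        # assumed linear schedule, for illustration only
alphas_cumprod = np.cumprod(1.0 - betas)

t = 500
alpha_prod_t = alphas_cumprod[t]
alpha_prod_t_prev = alphas_cumprod[t - 1]

# posterior variance of formula (7)
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * betas[t]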
@@ -112,37 +112,37 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         return variance
 
-    def step(self, residual, image, t):
+    def step(self, residual, sample, t):
         # 1. compute alphas, betas
         alpha_prod_t = self.get_alpha_prod(t)
         alpha_prod_t_prev = self.get_alpha_prod(t - 1)
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
-        # 2. compute predicted original image from predicted noise also called
+        # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_image = (image - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
+        pred_original_sample = (sample - beta_prod_t ** (0.5) * residual) / alpha_prod_t ** (0.5)
 
         # 3. Clip "predicted x_0"
-        if self.clip_predicted_image:
-            pred_original_image = self.clip(pred_original_image, -1, 1)
+        if self.clip_predicted_sample:
+            pred_original_sample = self.clip(pred_original_sample, -1, 1)
 
-        # 4. Compute coefficients for pred_original_image x_0 and current image x_t
+        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_image_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
-        current_image_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t
+        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.get_beta(t)) / beta_prod_t
+        current_sample_coeff = self.get_alpha(t) ** (0.5) * beta_prod_t_prev / beta_prod_t
 
-        # 5. Compute predicted previous image µ_t
+        # 5. Compute predicted previous sample µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_prev_image = pred_original_image_coeff * pred_original_image + current_image_coeff * image
+        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
 
-        return pred_prev_image
+        return pred_prev_sample
 
-    def forward_step(self, original_image, noise, t):
+    def forward_step(self, original_sample, noise, t):
         sqrt_alpha_prod = self.get_alpha_prod(t) ** 0.5
         sqrt_one_minus_alpha_prod = (1 - self.get_alpha_prod(t)) ** 0.5
-        noisy_image = sqrt_alpha_prod * original_image + sqrt_one_minus_alpha_prod * noise
-        return noisy_image
+        noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
+        return noisy_sample
 
     def __len__(self):
         return self.timesteps
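The reverse step and forward_step above reduce to a few lines of array math; the sketch below mirrors formulas (15) and (7) and the q(x_t | x_0) noising with placeholder inputs and an assumed schedule, not the scheduler's API:

import numpy as np

betas = np.linspace(0.0001, 0.02, 1000)        # assumed linear schedule, for illustration only
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)

t = 500
alpha_prod_t, alpha_prod_t_prev = alphas_cumprod[t], alphas_cumprod[t - 1]
beta_prod_t, beta_prod_t_prev = 1 - alpha_prod_t, 1 - alpha_prod_t_prev

sample = np.zeros(4)      # x_t, placeholder
residual = np.ones(4)     # predicted noise, placeholder

# formula (15): predicted x_0; formula (7): coefficients and posterior mean
pred_original_sample = (sample - beta_prod_t ** 0.5 * residual) / alpha_prod_t ** 0.5
pred_original_sample_coeff = (alpha_prod_t_prev ** 0.5 * betas[t]) / beta_prod_t
current_sample_coeff = alphas[t] ** 0.5 * beta_prod_t_prev / beta_prod_t
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

# forward noising q(x_t | x_0): sqrt(alpha-bar_t) * x_0 + sqrt(1 - alpha-bar_t) * noise
original_sample, noise = np.zeros(4), np.random.randn(4)
noisy_sample = alpha_prod_t ** 0.5 * original_sample + (1 - alpha_prod_t) ** 0.5 * noise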
@@ -62,7 +62,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         # running values
         self.cur_residual = 0
-        self.cur_image = None
+        self.cur_sample = None
         self.ets = []
         self.warmup_time_steps = {}
         self.time_steps = {}
@@ -100,7 +100,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         return self.time_steps[num_inference_steps]
 
-    def step_prk(self, residual, image, t, num_inference_steps):
+    def step_prk(self, residual, sample, t, num_inference_steps):
         # TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
         warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
 
@@ -110,7 +110,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         if t % 4 == 0:
             self.cur_residual += 1 / 6 * residual
             self.ets.append(residual)
-            self.cur_image = image
+            self.cur_sample = sample
         elif (t - 1) % 4 == 0:
             self.cur_residual += 1 / 3 * residual
         elif (t - 2) % 4 == 0:
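The 1/6 and 1/3 weights visible here are consistent with the classical Runge-Kutta (RK4) average (k1 + 2*k2 + 2*k3 + k4) / 6; a standalone illustration of that weighting, independent of the scheduler code:

# classical RK4-style average of four slope estimates k1..k4
k1, k2, k3, k4 = 1.0, 2.0, 2.0, 1.0                     # placeholder slopes
avg = (1 / 6) * k1 + (1 / 3) * k2 + (1 / 3) * k3 + (1 / 6) * k4
assert abs(avg - (k1 + 2 * k2 + 2 * k3 + k4) / 6) < 1e-12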
@@ -119,9 +119,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             residual = self.cur_residual + 1 / 6 * residual
             self.cur_residual = 0
 
-        return self.transfer(self.cur_image, t_prev, t_next, residual)
+        return self.transfer(self.cur_sample, t_prev, t_next, residual)
 
-    def step_plms(self, residual, image, t, num_inference_steps):
+    def step_plms(self, residual, sample, t, num_inference_steps):
         timesteps = self.get_time_steps(num_inference_steps)
 
         t_prev = timesteps[t]
@@ -130,7 +130,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
 
-        return self.transfer(image, t_prev, t_next, residual)
+        return self.transfer(sample, t_prev, t_next, residual)
 
     def transfer(self, x, t, t_next, et):
         # TODO(Patrick): clean up to be compatible with numpy and give better names
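The (1/24) * (55, -59, 37, -9) combination above matches the coefficients of the 4th-order Adams-Bashforth linear multistep method; a minimal sketch of keeping an ets history and forming that combination with placeholder residuals, not the scheduler's API:

import numpy as np

ets = []                                  # history of recent residuals, newest last
for step in range(6):
    residual = np.full(4, float(step))    # placeholder model output for this step
    ets.append(residual)
    if len(ets) >= 4:
        # 4th-order linear multistep combination used by step_plms
        residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])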