1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Fix old bug introduced when prediction type is "sample"

This commit is contained in:
Clarence Chen
2023-03-09 12:22:26 -08:00
parent 3d9346609e
commit 64ed3eadaf

View File

@@ -224,12 +224,13 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
@@ -237,7 +238,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
)
# 4. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
# 5. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction