From 64ed3eadaf373b3fcd8995c38486797392be64c5 Mon Sep 17 00:00:00 2001 From: Clarence Chen Date: Thu, 9 Mar 2023 12:22:26 -0800 Subject: [PATCH] Fix old bug introduced when prediction type is "sample" --- src/diffusers/schedulers/scheduling_ddim_inverse.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index 392ebd1c5c..113afcef55 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -224,12 +224,13 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf if self.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output elif self.config.prediction_type == "sample": pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) elif self.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output - # predict V - model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" @@ -237,7 +238,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): ) # 4. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf - pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon # 5. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction