diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 89d90ba60a..5945b0c1ea 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -129,6 +129,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): an offset added to the inference steps. You can use a combination of `offset=1` and `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in stable diffusion. + prediction_type (`Literal["epsilon", "sample", "v"]`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample) or `v` (see section 2.4 + https://imagen.research.google/video/paper.pdf) """ @@ -181,8 +185,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): # For the final step, there is no previous alphas_cumprod because we are already at 0 # `set_alpha_to_one` decides whether we set this parameter simply to one or # whether we use the final alpha of the "non-previous" one. 
- self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] - self.final_sigma = torch.tensor(0.0) if set_alpha_to_one else self.sigmas[0] + if set_alpha_to_one: + self.final_alpha_cumprod = torch.tensor(1.0) + self.final_sigma = torch.tensor(0.0) # TODO rename set_alpha_to_one to something more general, since it also forces sigma=0 + else: + self.final_alpha_cumprod = self.alphas_cumprod[0] + self.final_sigma = self.sigmas[0] if prediction_type == "v" else None # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d403c4f595..9931d8c143 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -114,6 +114,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v` (see section 2.4 https://imagen.research.google/video/paper.pdf) + predict_epsilon (`bool`, default `True`): + deprecated flag (to be removed in v0.10.0) for epsilon vs. direct sample prediction. """ _compatible_classes = [ @@ -136,6 +138,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): variance_type: str = "fixed_small", clip_sample: bool = True, prediction_type: Literal["epsilon", "sample", "v"] = "epsilon", + predict_epsilon: bool = True, ): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) @@ -265,8 +268,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): if self.variance_type == "v_diffusion": assert self.prediction_type == "v", "Need to use v prediction with v_diffusion" message = ( - "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" - " DDPMScheduler.from_config(, predict_epsilon=True)`." 
+ "Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler =" + " DDPMScheduler.from_config(, prediction_type=epsilon)`." ) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: @@ -293,11 +296,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): sample * self.sqrt_alphas_cumprod[timestep] - model_output * self.sqrt_one_minus_alphas_cumprod[timestep] ) + + # no explicit check on predict_epsilon here; the deprecation handling above keeps it in sync + elif self.prediction_type == "sample" or not self.config.predict_epsilon: + pred_original_sample = model_output + elif self.prediction_type == "epsilon" or self.config.predict_epsilon: pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) - elif self.prediction_type == "sample": - pred_original_sample = model_output else: raise ValueError( f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`" diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index ab52171511..5a19d1059c 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -393,9 +393,9 @@ class DDPMSchedulerTest(SchedulerCommonTest): for clip_sample in [True, False]: self.check_over_configs(clip_sample=clip_sample) - def test_predict_epsilon(self): - for predict_epsilon in [True, False]: - self.check_over_configs(predict_epsilon=predict_epsilon) + def test_prediction_type(self): + for prediction_type in ["epsilon", "sample", "v"]: + self.check_over_configs(prediction_type=prediction_type) def test_deprecated_epsilon(self): deprecate("remove this test", "0.10.0", "remove") @@ -407,7 +407,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): time_step = 4 scheduler = scheduler_class(**scheduler_config) - scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config) + scheduler_eps = scheduler_class(prediction_type="sample", 
**scheduler_config) kwargs = {} if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):