diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index b5a352c785..8690e9f0b7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -1156,8 +1156,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline): # 7. Denoising loop where we obtain the cross-attention maps. num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order - with self.progress_bar(total=num_inference_steps - 2) as progress_bar: - for i, t in enumerate(timesteps[1:-1]): + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index 7006bd1339..392ebd1c5c 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -122,7 +122,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, clip_sample: bool = True, - set_alpha_to_one: bool = True, + set_alpha_to_zero: bool = True, steps_offset: int = 0, prediction_type: str = "epsilon", ): @@ -144,11 +144,12 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # At every step in ddim, we are looking into the previous alphas_cumprod - # For the final step, there is no previous alphas_cumprod because we are already at 0 - # `set_alpha_to_one` decides whether we set this parameter simply to one or - # whether we use the final alpha of the "non-previous" one. - self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + # At every step in inverted ddim, we are looking into the next alphas_cumprod + # For the final step, there is no next alphas_cumprod, and the index is out of bounds + # `set_alpha_to_zero` decides whether we set this parameter simply to zero + # in this case, self.step() just normalizes output by self.config.prediction_type + # or whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(0.0) if set_alpha_to_zero else self.alphas_cumprod[-1] # standard deviation of the initial noise distribution self.init_noise_sigma = 1.0 @@ -157,6 +158,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): self.num_inference_steps = None self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps).copy().astype(np.int64)) + # Copy from diffusers.schedulers.scheduling_ddim.DDIMScheduler.scale_model_input def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -205,23 +207,44 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin): variance_noise: Optional[torch.FloatTensor] = None, return_dict: bool = True, ) -> Union[DDIMSchedulerOutput, Tuple]: - e_t = model_output - - x = sample + # 1. get previous step value (=t+1) prev_timestep = timestep + self.config.num_train_timesteps // self.num_inference_steps - a_t = self.alphas_cumprod[timestep - 1] - a_prev = self.alphas_cumprod[prev_timestep - 1] if prev_timestep >= 0 else self.final_alpha_cumprod + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] + if prev_timestep < self.config.num_train_timesteps + else self.final_alpha_cumprod + ) - pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t.sqrt() + beta_prod_t = 1 - alpha_prod_t - dir_xt = (1.0 - a_prev).sqrt() * e_t + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) - prev_sample = a_prev.sqrt() * pred_x0 + dir_xt + # 4. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output + + # 5. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction if not return_dict: - return (prev_sample, pred_x0) - return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0) + return (prev_sample, pred_original_sample) + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) def __len__(self): return self.config.num_train_timesteps