From fe4837a96e1ad34ff3f1981e9b92be063a3c8d73 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Thu, 14 Sep 2023 06:48:35 -1000 Subject: [PATCH] add step_index and clear noise_sampler at begining of each loop (#5024) Co-authored-by: yiyixuxu --- .../schedulers/scheduling_dpmsolver_sde.py | 48 ++++++++++++++++--- 1 file changed, 41 insertions(+), 7 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 59bd6ccdfa..d39efbe724 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -199,6 +199,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): self.use_karras_sigmas = use_karras_sigmas self.noise_sampler = None self.noise_sampler_seed = noise_sampler_seed + self._step_index = None # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): @@ -219,6 +220,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): return indices[pos].item() + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. 
for image-to-image) + if len(index_candidates) > 1: + step_index = index_candidates[1] + else: + step_index = index_candidates[0] + + self._step_index = step_index.item() + @property def init_noise_sigma(self): # standard deviation of the initial noise distribution @@ -227,6 +246,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): return (self.sigmas.max() ** 2 + 1) ** 0.5 + @property + def step_index(self): + """ + The index counter for current timestep. It will increase by 1 after each scheduler step. + """ + return self._step_index + def scale_model_input( self, sample: torch.FloatTensor, @@ -246,9 +272,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): `torch.FloatTensor`: A scaled input sample. """ - step_index = self.index_for_timestep(timestep) + if self.step_index is None: + self._init_step_index(timestep) - sigma = self.sigmas[step_index] + sigma = self.sigmas[self.step_index] sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma sample = sample / ((sigma_input**2 + 1) ** 0.5) return sample @@ -321,6 +348,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): self.sample = None self.mid_point_sigma = None + self._step_index = None + self.noise_sampler = None + # for exp beta schedules, such as the one for `pipeline_shap_e.py` # we need an index counter self._index_counter = defaultdict(int) @@ -411,7 +441,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. 
""" - step_index = self.index_for_timestep(timestep) + if self.step_index is None: + self._init_step_index(timestep) # advance index counter by 1 timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep @@ -430,12 +461,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): return _sigma.log().neg() if self.state_in_first_order: - sigma = self.sigmas[step_index] - sigma_next = self.sigmas[step_index + 1] + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] else: # 2nd order - sigma = self.sigmas[step_index - 1] - sigma_next = self.sigmas[step_index] + sigma = self.sigmas[self.step_index - 1] + sigma_next = self.sigmas[self.step_index] # Set the midpoint and step size for the current step midpoint_ratio = 0.5 @@ -488,6 +519,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): self.sample = None self.mid_point_sigma = None + # upon completion increase step index by one + self._step_index += 1 + if not return_dict: return (prev_sample,)