mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add step_index and clear noise_sampler at begining of each loop (#5024)
Co-authored-by: yiyixuxu <yixu310@gmail,com>
This commit is contained in:
@@ -199,6 +199,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
self.noise_sampler = None
|
||||
self.noise_sampler_seed = noise_sampler_seed
|
||||
self._step_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
@@ -219,6 +220,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
if len(index_candidates) > 1:
|
||||
step_index = index_candidates[1]
|
||||
else:
|
||||
step_index = index_candidates[0]
|
||||
|
||||
self._step_index = step_index.item()
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
@@ -227,6 +246,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -246,9 +272,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[step_index]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma
|
||||
sample = sample / ((sigma_input**2 + 1) ** 0.5)
|
||||
return sample
|
||||
@@ -321,6 +348,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sample = None
|
||||
self.mid_point_sigma = None
|
||||
|
||||
self._step_index = None
|
||||
self.noise_sampler = None
|
||||
|
||||
# for exp beta schedules, such as the one for `pipeline_shap_e.py`
|
||||
# we need an index counter
|
||||
self._index_counter = defaultdict(int)
|
||||
@@ -411,7 +441,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
step_index = self.index_for_timestep(timestep)
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# advance index counter by 1
|
||||
timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
|
||||
@@ -430,12 +461,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
return _sigma.log().neg()
|
||||
|
||||
if self.state_in_first_order:
|
||||
sigma = self.sigmas[step_index]
|
||||
sigma_next = self.sigmas[step_index + 1]
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
else:
|
||||
# 2nd order
|
||||
sigma = self.sigmas[step_index - 1]
|
||||
sigma_next = self.sigmas[step_index]
|
||||
sigma = self.sigmas[self.step_index - 1]
|
||||
sigma_next = self.sigmas[self.step_index]
|
||||
|
||||
# Set the midpoint and step size for the current step
|
||||
midpoint_ratio = 0.5
|
||||
@@ -488,6 +519,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sample = None
|
||||
self.mid_point_sigma = None
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user