
add step_index and clear noise_sampler at beginning of each loop (#5024)

Co-authored-by: yiyixuxu <yixu310@gmail.com>
Author: YiYi Xu
Date: 2023-09-14 06:48:35 -10:00 (committed by GitHub)
Parent: 342c5c02c0
Commit: fe4837a96e
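
As context (not part of the commit), here is a minimal sketch of how the new counter behaves in a typical denoising loop. The random tensors stand in for real latents and a real noise-prediction model, and this scheduler needs the optional torchsde dependency installed:

import torch
from diffusers import DPMSolverSDEScheduler

scheduler = DPMSolverSDEScheduler()
scheduler.set_timesteps(10)              # resets _step_index and noise_sampler to None

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)   # lazily initializes step_index on the first call
    model_output = torch.randn_like(sample)                 # placeholder for a real UNet call
    sample = scheduler.step(model_output, t, sample).prev_sample
    # scheduler.step_index has advanced by one after each step() call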

@@ -199,6 +199,7 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         self.use_karras_sigmas = use_karras_sigmas
         self.noise_sampler = None
         self.noise_sampler_seed = noise_sampler_seed
+        self._step_index = None

     # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
     def index_for_timestep(self, timestep, schedule_timesteps=None):
@@ -219,6 +220,24 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):

         return indices[pos].item()

+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+
+        index_candidates = (self.timesteps == timestep).nonzero()
+
+        # The sigma index that is taken for the **very** first `step`
+        # is always the second index (or the last index if there is only 1)
+        # This way we can ensure we don't accidentally skip a sigma in
+        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+        if len(index_candidates) > 1:
+            step_index = index_candidates[1]
+        else:
+            step_index = index_candidates[0]
+
+        self._step_index = step_index.item()
+
     @property
     def init_noise_sigma(self):
         # standard deviation of the initial noise distribution
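
A small illustration of the branch above (values are made up, not taken from the scheduler): when a timestep value appears more than once in self.timesteps, as can happen for second-order schedulers, the very first step starts from the second match so no sigma is skipped when denoising begins mid-schedule:

import torch

# hypothetical schedule with duplicated timestep values
timesteps = torch.tensor([999, 999, 749, 749, 499, 499, 249, 249, 0])
t = torch.tensor(999)

index_candidates = (timesteps == t).nonzero()
step_index = index_candidates[1] if len(index_candidates) > 1 else index_candidates[0]
print(step_index.item())  # 1 -- the second occurrence, as described in the comment above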
@@ -227,6 +246,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):

         return (self.sigmas.max() ** 2 + 1) ** 0.5

+    @property
+    def step_index(self):
+        """
+        The index counter for the current timestep. It will increase by 1 after each scheduler step.
+        """
+        return self._step_index
+
     def scale_model_input(
         self,
         sample: torch.FloatTensor,
@@ -246,9 +272,10 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
             `torch.FloatTensor`:
                 A scaled input sample.
         """
-        step_index = self.index_for_timestep(timestep)
+        if self.step_index is None:
+            self._init_step_index(timestep)

-        sigma = self.sigmas[step_index]
+        sigma = self.sigmas[self.step_index]
         sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma
         sample = sample / ((sigma_input**2 + 1) ** 0.5)
         return sample
@@ -321,6 +348,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         self.sample = None
         self.mid_point_sigma = None

+        self._step_index = None
+        self.noise_sampler = None
+
         # for exp beta schedules, such as the one for `pipeline_shap_e.py`
         # we need an index counter
         self._index_counter = defaultdict(int)
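
A quick sketch (not part of the commit) of the reset behavior added here, assuming the scheduler defaults and torchsde installed: after set_timesteps(), both the step counter and the cached noise sampler are cleared, and the sampler is rebuilt lazily on the next step():

from diffusers import DPMSolverSDEScheduler

scheduler = DPMSolverSDEScheduler()
scheduler.set_timesteps(25)

assert scheduler.step_index is None      # counter stays unset until scale_model_input()/step() is called
assert scheduler.noise_sampler is None   # a fresh noise sampler is created on the first step()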
@@ -411,7 +441,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
         """
-        step_index = self.index_for_timestep(timestep)
+        if self.step_index is None:
+            self._init_step_index(timestep)

         # advance index counter by 1
         timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep
@@ -430,12 +461,12 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
             return _sigma.log().neg()

         if self.state_in_first_order:
-            sigma = self.sigmas[step_index]
-            sigma_next = self.sigmas[step_index + 1]
+            sigma = self.sigmas[self.step_index]
+            sigma_next = self.sigmas[self.step_index + 1]
         else:
             # 2nd order
-            sigma = self.sigmas[step_index - 1]
-            sigma_next = self.sigmas[step_index]
+            sigma = self.sigmas[self.step_index - 1]
+            sigma_next = self.sigmas[self.step_index]

         # Set the midpoint and step size for the current step
         midpoint_ratio = 0.5
@@ -488,6 +519,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
             self.sample = None
             self.mid_point_sigma = None

+        # upon completion increase step index by one
+        self._step_index += 1
+
         if not return_dict:
             return (prev_sample,)