1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Bug fix] "previous_timestep()" in DDPM scheduling compatible with "trailing" and "linspace" options (#9384)

* Update scheduling_ddpm.py

* fix copies

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
This commit is contained in:
Anand Kumar
2024-12-03 15:57:52 -08:00
committed by GitHub
parent 619b9658e2
commit 5effcd3e64
4 changed files with 8 additions and 24 deletions

View File

@@ -548,16 +548,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return self.config.num_train_timesteps
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
prev_t = timestep - 1
return prev_t

View File

@@ -639,16 +639,12 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
prev_t = timestep - 1
return prev_t

View File

@@ -643,16 +643,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
prev_t = timestep - 1
return prev_t

View File

@@ -680,16 +680,12 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
if self.custom_timesteps:
if self.custom_timesteps or self.num_inference_steps:
index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
if index == self.timesteps.shape[0] - 1:
prev_t = torch.tensor(-1)
else:
prev_t = self.timesteps[index + 1]
else:
num_inference_steps = (
self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
)
prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
prev_t = timestep - 1
return prev_t