Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[Bug fix] "previous_timestep()" in DDPM scheduling compatible with "trailing" and "linspace" options (#9384)
* Update scheduling_ddpm.py

* fix copies

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
src/diffusers/schedulers/scheduling_ddpm.py

@@ -548,16 +548,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         return self.config.num_train_timesteps
 
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
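The change makes previous_timestep() read the actual next entry of self.timesteps whenever an inference schedule has been set, instead of assuming consecutive timesteps are a constant num_train_timesteps // num_inference_steps apart. That assumption only holds for the default "leading" spacing, not for "trailing" or "linspace". Below is a minimal sketch of the fixed behavior; the printed values assume 1000 training timesteps and depend on the rounding done inside set_timesteps():

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000, timestep_spacing="linspace")
scheduler.set_timesteps(num_inference_steps=3)
print(scheduler.timesteps)  # e.g. tensor([999, 500, 0])

# Old arithmetic: 999 - 1000 // 3 = 666, a timestep this schedule never visits.
# New lookup: the next entry of self.timesteps is returned instead.
print(scheduler.previous_timestep(torch.tensor(999)))  # tensor(500)
print(scheduler.previous_timestep(torch.tensor(0)))    # tensor(-1), schedule exhausted

The lookup branch itself is not new; it already handled custom_timesteps. The fix extends it to any call made after set_timesteps().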
src/diffusers/schedulers/scheduling_ddpm_parallel.py

@@ -639,16 +639,12 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
src/diffusers/schedulers/scheduling_lcm.py

@@ -643,16 +643,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
    def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
src/diffusers/schedulers/scheduling_tcd.py

@@ -680,16 +680,12 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
     def previous_timestep(self, timestep):
-        if self.custom_timesteps:
+        if self.custom_timesteps or self.num_inference_steps:
             index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
             if index == self.timesteps.shape[0] - 1:
                 prev_t = torch.tensor(-1)
             else:
                 prev_t = self.timesteps[index + 1]
         else:
-            num_inference_steps = (
-                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
-            )
-            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
-
+            prev_t = timestep - 1
         return prev_t
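The identical replacement lands in DDPMParallelScheduler, LCMScheduler, and TCDScheduler through their "# Copied from" markers, which is the "fix copies" step named in the commit message. Note that the simplified else branch is not a behavior change: when set_timesteps() has never been called, self.num_inference_steps is None, so the old fallback computed timestep - num_train_timesteps // num_train_timesteps, which is just timestep - 1. A short sketch of that unchanged training-schedule path:

import torch
from diffusers import DDPMScheduler

# No set_timesteps() call: num_inference_steps is None and custom_timesteps is
# False, so previous_timestep() takes the else branch and steps through the
# full training schedule one timestep at a time.
scheduler = DDPMScheduler(num_train_timesteps=1000)
print(scheduler.previous_timestep(torch.tensor(123)))  # tensor(122)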