From 5effcd3e6461490ab27171d7c576d0ea4909a4a8 Mon Sep 17 00:00:00 2001 From: Anand Kumar <63339285+AnandK27@users.noreply.github.com> Date: Tue, 3 Dec 2024 15:57:52 -0800 Subject: [PATCH] [Bug fix] "previous_timestep()" in DDPM scheduling compatible with "trailing" and "linspace" options (#9384) * Update scheduling_ddpm.py * fix copies --------- Co-authored-by: YiYi Xu Co-authored-by: hlky --- src/diffusers/schedulers/scheduling_ddpm.py | 8 ++------ src/diffusers/schedulers/scheduling_ddpm_parallel.py | 8 ++------ src/diffusers/schedulers/scheduling_lcm.py | 8 ++------ src/diffusers/schedulers/scheduling_tcd.py | 8 ++------ 4 files changed, 8 insertions(+), 24 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 468fdf61a9..eb40d79b9f 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -548,16 +548,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): return self.config.num_train_timesteps def previous_timestep(self, timestep): - if self.custom_timesteps: + if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = self.timesteps[index + 1] else: - num_inference_steps = ( - self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps - ) - prev_t = timestep - self.config.num_train_timesteps // num_inference_steps - + prev_t = timestep - 1 return prev_t diff --git a/src/diffusers/schedulers/scheduling_ddpm_parallel.py b/src/diffusers/schedulers/scheduling_ddpm_parallel.py index f377ee6e8c..20ad7a4c92 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_parallel.py +++ b/src/diffusers/schedulers/scheduling_ddpm_parallel.py @@ -639,16 +639,12 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep def previous_timestep(self, timestep): - if self.custom_timesteps: + if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = self.timesteps[index + 1] else: - num_inference_steps = ( - self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps - ) - prev_t = timestep - self.config.num_train_timesteps // num_inference_steps - + prev_t = timestep - 1 return prev_t diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index f1aa09ab17..686b686f68 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -643,16 +643,12 @@ class LCMScheduler(SchedulerMixin, ConfigMixin): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep def previous_timestep(self, timestep): - if self.custom_timesteps: + if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = self.timesteps[index + 1] else: - num_inference_steps = ( - self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps - ) - prev_t = timestep - self.config.num_train_timesteps // num_inference_steps - + prev_t = timestep - 1 return prev_t diff --git a/src/diffusers/schedulers/scheduling_tcd.py b/src/diffusers/schedulers/scheduling_tcd.py index 580224404c..5d60383142 100644 --- a/src/diffusers/schedulers/scheduling_tcd.py +++ b/src/diffusers/schedulers/scheduling_tcd.py @@ -680,16 +680,12 @@ class TCDScheduler(SchedulerMixin, ConfigMixin): # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep def previous_timestep(self, timestep): - if self.custom_timesteps: + if self.custom_timesteps or self.num_inference_steps: index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = self.timesteps[index + 1] else: - num_inference_steps = ( - self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps - ) - prev_t = timestep - self.config.num_train_timesteps // num_inference_steps - + prev_t = timestep - 1 return prev_t