From e222246b4e7b60db7fe5fd27dc187bce446b5b56 Mon Sep 17 00:00:00 2001 From: hlky Date: Wed, 18 Dec 2024 12:22:10 +0000 Subject: [PATCH] Fix sigma_last with use_flow_sigmas (#10267) --- src/diffusers/schedulers/scheduling_deis_multistep.py | 1 + .../schedulers/scheduling_dpmsolver_multistep_inverse.py | 3 +++ src/diffusers/schedulers/scheduling_sasolver.py | 1 + src/diffusers/schedulers/scheduling_unipc_multistep.py | 9 +++++++++ 4 files changed, 14 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 3350c3373e..6a653f183b 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -289,6 +289,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): sigmas = 1.0 - alphas sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 19399a724a..971817f7b7 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -291,14 +291,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): elif self.config.use_exponential_sigmas: sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) elif self.config.use_beta_sigmas: sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) elif self.config.use_flow_sigmas: alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1) sigmas = 1.0 - alphas sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_max = ( diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index 41a471275f..d45c93880b 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -318,6 +318,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin): sigmas = 1.0 - alphas sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() + sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index c6434c6f87..0150042630 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -381,6 +381,15 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin): sigmas = 1.0 - alphas sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy() timesteps = (sigmas * self.config.num_train_timesteps).copy() + if self.config.final_sigmas_type == "sigma_min": + sigma_last = sigmas[-1] + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) else: sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) if self.config.final_sigmas_type == "sigma_min":