diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 2f21faa2bf..92975a3ffd 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -15,7 +15,6 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit -import pdb from typing import Union import numpy as np @@ -55,39 +54,35 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): # self.num_inference_steps = None self.timesteps = None - self.set_sigmas(self.num_train_timesteps) + self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps) self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps): + def set_timesteps(self, num_inference_steps, sampling_eps=None): + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps tensor_format = getattr(self, "tensor_format", "pt") if tensor_format == "np": - self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps) + self.timesteps = np.linspace(1, sampling_eps, num_inference_steps) elif tensor_format == "pt": - self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps) + self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps) else: raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - def set_sigmas(self, num_inference_steps): + def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None): + sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min + sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max + sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps if self.timesteps is None: - self.set_timesteps(num_inference_steps) + self.set_timesteps(num_inference_steps, sampling_eps) tensor_format = getattr(self, "tensor_format", "pt") if tensor_format == "np": - self.discrete_sigmas = np.exp( - np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) - ) - self.sigmas = np.array( - [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps] - ) + self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) + self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) elif tensor_format == "pt": - self.discrete_sigmas = torch.exp( - torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps) - ) - self.sigmas = torch.tensor( - [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps] - ) + self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps)) + self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps]) else: raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")