From 135acd83af86b02c1dfb3bdb5650d19ef10332b2 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 26 Jun 2022 00:56:18 +0000 Subject: [PATCH] fix bug --- src/diffusers/schedulers/scheduling_sde_ve.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 652314b9c9..2456afad7d 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -24,13 +24,12 @@ from .scheduling_utils import SchedulerMixin class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): - def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, N=2, sampling_eps=1e-5, tensor_format="np"): + def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"): super().__init__() self.register_to_config( snr=snr, sigma_min=sigma_min, sigma_max=sigma_max, - N=N, sampling_eps=sampling_eps, ) @@ -54,7 +53,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): def step_pred(self, result, x, t): t = t * torch.ones(x.shape[0], device=x.device) - timestep = (t * (2 - 1)).long() + timestep = (t * (len(self.timesteps) - 1)).long() sigma = self.discrete_sigmas.to(t.device)[timestep] adjacent_sigma = torch.where(