mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix score sde ve scheduler
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
|
||||
|
||||
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
|
||||
import pdb
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
@@ -55,39 +54,35 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
# self.num_inference_steps = None
|
||||
self.timesteps = None
|
||||
|
||||
self.set_sigmas(self.num_train_timesteps)
|
||||
self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
|
||||
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def set_timesteps(self, num_inference_steps):
|
||||
def set_timesteps(self, num_inference_steps, sampling_eps=None):
|
||||
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
if tensor_format == "np":
|
||||
self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps)
|
||||
self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
|
||||
elif tensor_format == "pt":
|
||||
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
|
||||
self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
|
||||
else:
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
def set_sigmas(self, num_inference_steps):
|
||||
def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None):
|
||||
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
|
||||
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
|
||||
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
||||
if self.timesteps is None:
|
||||
self.set_timesteps(num_inference_steps)
|
||||
self.set_timesteps(num_inference_steps, sampling_eps)
|
||||
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
if tensor_format == "np":
|
||||
self.discrete_sigmas = np.exp(
|
||||
np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
|
||||
)
|
||||
self.sigmas = np.array(
|
||||
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
|
||||
)
|
||||
self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
|
||||
self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
|
||||
elif tensor_format == "pt":
|
||||
self.discrete_sigmas = torch.exp(
|
||||
torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
|
||||
)
|
||||
self.sigmas = torch.tensor(
|
||||
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
|
||||
)
|
||||
self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
|
||||
self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
|
||||
else:
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user