1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix score sde ve scheduler

This commit is contained in:
Patrick von Platen
2022-07-20 21:02:40 +00:00
parent 919e27d357
commit 760dcb1ffc

View File

@@ -15,7 +15,6 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
# TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
import pdb
from typing import Union
import numpy as np
@@ -55,39 +54,35 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# self.num_inference_steps = None
self.timesteps = None
self.set_sigmas(self.num_train_timesteps)
self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps):
def set_timesteps(self, num_inference_steps, sampling_eps=None):
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps)
self.timesteps = np.linspace(1, sampling_eps, num_inference_steps)
elif tensor_format == "pt":
self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def set_sigmas(self, num_inference_steps):
def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None):
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
if self.timesteps is None:
self.set_timesteps(num_inference_steps)
self.set_timesteps(num_inference_steps, sampling_eps)
tensor_format = getattr(self, "tensor_format", "pt")
if tensor_format == "np":
self.discrete_sigmas = np.exp(
np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
)
self.sigmas = np.array(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
self.discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
self.sigmas = np.array([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
elif tensor_format == "pt":
self.discrete_sigmas = torch.exp(
torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
)
self.sigmas = torch.tensor(
[self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
)
self.discrete_sigmas = torch.exp(torch.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))
self.sigmas = torch.tensor([sigma_min * (sigma_max / sigma_min) ** t for t in self.timesteps])
else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")