mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[Schedulers] Add beta sigmas / beta noise schedule (#9509)
Add beta sigmas / beta noise schedule
This commit is contained in:
@@ -20,11 +20,14 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from ..utils import BaseOutput, is_scipy_available, logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
|
||||
|
||||
|
||||
if is_scipy_available():
|
||||
import scipy.stats
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@@ -160,6 +163,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
the sigmas are determined according to a sequence of noise levels {σi}.
|
||||
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
@@ -189,6 +195,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
interpolation_type: str = "linear",
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
@@ -197,8 +204,12 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
|
||||
):
|
||||
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -241,6 +252,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = False
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
self.use_exponential_sigmas = use_exponential_sigmas
|
||||
self.use_beta_sigmas = use_beta_sigmas
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
@@ -340,6 +352,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
|
||||
if timesteps is not None and self.config.use_exponential_sigmas:
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
|
||||
if timesteps is not None and self.config.use_beta_sigmas:
|
||||
raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
|
||||
if (
|
||||
timesteps is not None
|
||||
and self.config.timestep_type == "continuous"
|
||||
@@ -408,6 +422,10 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
elif self.config.final_sigmas_type == "zero":
|
||||
@@ -502,6 +520,37 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
|
||||
return sigmas
|
||||
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = torch.Tensor(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
Reference in New Issue
Block a user