1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Scheduler] introduce sigma schedule. (#7649)

* introduce sigma schedule.

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* address yiyi

* update docstrings.

* implement the schedule for EDMDPMSolverMultistepScheduler

---------

Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Sayak Paul
2024-04-27 07:40:35 +05:30
committed by GitHub
parent 9d16daaf64
commit 56bd7e67c2
2 changed files with 59 additions and 9 deletions

View File

@@ -14,6 +14,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
import math
from typing import List, Optional, Tuple, Union
import numpy as np
@@ -44,6 +45,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
range is [0.2, 80.0].
sigma_data (`float`, *optional*, defaults to 0.5):
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
sigma_schedule (`str`, *optional*, defaults to `karras`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
(https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
solver_order (`int`, defaults to 2):
@@ -89,6 +94,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sigma_min: float = 0.002,
sigma_max: float = 80.0,
sigma_data: float = 0.5,
sigma_schedule: str = "karras",
num_train_timesteps: int = 1000,
prediction_type: str = "epsilon",
rho: float = 7.0,
@@ -121,7 +127,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
)
ramp = torch.linspace(0, 1, num_train_timesteps)
sigmas = self._compute_sigmas(ramp)
if sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
elif sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
self.timesteps = self.precondition_noise(sigmas)
self.sigmas = self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -236,7 +246,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = num_inference_steps
ramp = np.linspace(0, 1, self.num_inference_steps)
sigmas = self._compute_sigmas(ramp)
if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas)
@@ -262,10 +275,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -273,6 +285,18 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
"""Implementation closely follows k-diffusion.
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
return sigmas
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
@@ -65,6 +66,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
range is [0.2, 80.0].
sigma_data (`float`, *optional*, defaults to 0.5):
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
sigma_schedule (`str`, *optional*, defaults to `karras`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
(https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
prediction_type (`str`, defaults to `epsilon`, *optional*):
@@ -84,15 +89,23 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
sigma_min: float = 0.002,
sigma_max: float = 80.0,
sigma_data: float = 0.5,
sigma_schedule: str = "karras",
num_train_timesteps: int = 1000,
prediction_type: str = "epsilon",
rho: float = 7.0,
):
if sigma_schedule not in ["karras", "exponential"]:
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
# setable values
self.num_inference_steps = None
ramp = torch.linspace(0, 1, num_train_timesteps)
sigmas = self._compute_sigmas(ramp)
if sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
elif sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
self.timesteps = self.precondition_noise(sigmas)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
@@ -200,7 +213,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = num_inference_steps
ramp = np.linspace(0, 1, self.num_inference_steps)
sigmas = self._compute_sigmas(ramp)
if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas)
@@ -211,9 +227,8 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -221,6 +236,17 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
"""Implementation closely follows k-diffusion.
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep