mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Scheduler] introduce sigma schedule. (#7649)
* introduce sigma schedule. Co-authored-by: Suraj Patil <surajp815@gmail.com> * address yiyi * update docstrings. * implement the schedule for EDMDPMSolverMultistepScheduler --------- Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -44,6 +45,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
range is [0.2, 80.0].
|
||||
sigma_data (`float`, *optional*, defaults to 0.5):
|
||||
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
|
||||
sigma_schedule (`str`, *optional*, defaults to `karras`):
|
||||
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
|
||||
(https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
|
||||
incorporated in this model: https://huggingface.co/stabilityai/cosxl.
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
solver_order (`int`, defaults to 2):
|
||||
@@ -89,6 +94,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigma_min: float = 0.002,
|
||||
sigma_max: float = 80.0,
|
||||
sigma_data: float = 0.5,
|
||||
sigma_schedule: str = "karras",
|
||||
num_train_timesteps: int = 1000,
|
||||
prediction_type: str = "epsilon",
|
||||
rho: float = 7.0,
|
||||
@@ -121,7 +127,11 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
ramp = torch.linspace(0, 1, num_train_timesteps)
|
||||
sigmas = self._compute_sigmas(ramp)
|
||||
if sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(ramp)
|
||||
elif sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(ramp)
|
||||
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
self.sigmas = self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
@@ -236,7 +246,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
ramp = np.linspace(0, 1, self.num_inference_steps)
|
||||
sigmas = self._compute_sigmas(ramp)
|
||||
if self.config.sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(ramp)
|
||||
elif self.config.sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(ramp)
|
||||
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
@@ -262,10 +275,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
|
||||
def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
|
||||
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min = sigma_min or self.config.sigma_min
|
||||
sigma_max = sigma_max or self.config.sigma_max
|
||||
|
||||
@@ -273,6 +285,18 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
|
||||
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
|
||||
"""Implementation closely follows k-diffusion.
|
||||
|
||||
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
|
||||
"""
|
||||
sigma_min = sigma_min or self.config.sigma_min
|
||||
sigma_max = sigma_max or self.config.sigma_max
|
||||
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
@@ -65,6 +66,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
range is [0.2, 80.0].
|
||||
sigma_data (`float`, *optional*, defaults to 0.5):
|
||||
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
|
||||
sigma_schedule (`str`, *optional*, defaults to `karras`):
|
||||
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
|
||||
(https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was
|
||||
incorporated in this model: https://huggingface.co/stabilityai/cosxl.
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
@@ -84,15 +89,23 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigma_min: float = 0.002,
|
||||
sigma_max: float = 80.0,
|
||||
sigma_data: float = 0.5,
|
||||
sigma_schedule: str = "karras",
|
||||
num_train_timesteps: int = 1000,
|
||||
prediction_type: str = "epsilon",
|
||||
rho: float = 7.0,
|
||||
):
|
||||
if sigma_schedule not in ["karras", "exponential"]:
|
||||
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
|
||||
ramp = torch.linspace(0, 1, num_train_timesteps)
|
||||
sigmas = self._compute_sigmas(ramp)
|
||||
if sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(ramp)
|
||||
elif sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(ramp)
|
||||
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
@@ -200,7 +213,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
ramp = np.linspace(0, 1, self.num_inference_steps)
|
||||
sigmas = self._compute_sigmas(ramp)
|
||||
if self.config.sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(ramp)
|
||||
elif self.config.sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(ramp)
|
||||
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
@@ -211,9 +227,8 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
|
||||
def _compute_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
|
||||
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min = sigma_min or self.config.sigma_min
|
||||
sigma_max = sigma_max or self.config.sigma_max
|
||||
|
||||
@@ -221,6 +236,17 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
|
||||
return sigmas
|
||||
|
||||
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.FloatTensor:
|
||||
"""Implementation closely follows k-diffusion.
|
||||
|
||||
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
|
||||
"""
|
||||
sigma_min = sigma_min or self.config.sigma_min
|
||||
sigma_max = sigma_max or self.config.sigma_max
|
||||
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0)
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
|
||||
Reference in New Issue
Block a user