
Improve docstrings and type hints in scheduling_edm_euler.py (#12871)

* docs: add comprehensive docstrings and refine type hints for EDM scheduler methods and config parameters.

* refactor: Add type hints to DPM-Solver scheduler methods.
David El Malih authored 2026-01-07 20:18:00 +01:00, committed by GitHub
parent 6fb4c99f5a, commit 9fb6b89d49
3 changed files with 386 additions and 89 deletions


@@ -143,7 +143,20 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the input sample by scaling it according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor to precondition.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The scaled input sample.
"""
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
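For intuition, input preconditioning just rescales the sample so its variance stays near unit at any noise level. A minimal standalone sketch (values are illustrative; `sigma_data = 0.5` is the EDM default):

import torch

sigma_data = 0.5                    # std of the data distribution (EDM default)
sample = torch.randn(1, 4, 64, 64)  # e.g. a noisy latent
sigma = 10.0                        # current noise level

# c_in = 1 / sqrt(sigma^2 + sigma_data^2), as computed by _get_conditioning_c_in
c_in = 1 / ((sigma**2 + sigma_data**2) ** 0.5)
scaled_sample = sample * c_in       # what precondition_inputs returns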
@@ -155,7 +168,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigma.atan() / math.pi * 2
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
def precondition_outputs(self, sample, model_output, sigma):
def precondition_outputs(
self,
sample: torch.Tensor,
model_output: torch.Tensor,
sigma: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
Precondition the model outputs according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor.
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The denoised sample computed by combining the skip connection and output scaling.
"""
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
@@ -173,13 +206,13 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
need to scale the denoising model input depending on the current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The input sample tensor.
timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -242,8 +275,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.noise_sampler = None
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
def _compute_karras_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
Args:
ramp (`torch.Tensor`):
A tensor of values in [0, 1] representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed Karras sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -254,10 +306,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion.
def _compute_exponential_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
Args:
ramp (`torch.Tensor`):
A tensor of values representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed exponential sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -354,7 +423,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -540,7 +612,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
[g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
)
self.noise_sampler = BrownianTreeNoiseSampler(
model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
model_output,
sigma_min=self.config.sigma_min,
sigma_max=self.config.sigma_max,
seed=seed,
)
noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
model_output.device
@@ -612,7 +687,18 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return noisy_samples
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
def _get_conditioning_c_in(self, sigma):
def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
"""
Compute the input conditioning factor for the EDM formulation.
Args:
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`float` or `torch.Tensor`:
The input conditioning factor `c_in`.
"""
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in


@@ -175,13 +175,37 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the input sample by scaling it according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor to precondition.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The scaled input sample.
"""
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise
def precondition_noise(self, sigma):
def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the noise level by applying a logarithmic transformation.
Args:
sigma (`float` or `torch.Tensor`):
The sigma (noise level) value to precondition.
Returns:
`torch.Tensor`:
The preconditioned noise value computed as `0.25 * log(sigma)`.
"""
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
@@ -190,7 +214,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return c_noise
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
def precondition_outputs(self, sample, model_output, sigma):
def precondition_outputs(
self,
sample: torch.Tensor,
model_output: torch.Tensor,
sigma: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
Precondition the model outputs according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor.
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The denoised sample computed by combining the skip connection and output scaling.
"""
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
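Output preconditioning combines the raw network output with the noisy sample into a denoised estimate via skip and output coefficients. A minimal sketch of that combination (values illustrative; the `c_out` shown is the epsilon-prediction case of the EDM formulation):

import torch

sigma, sigma_data = 1.0, 0.5
sample = torch.randn(1, 4, 64, 64)        # current noisy sample
model_output = torch.randn(1, 4, 64, 64)  # raw network output

c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
denoised = c_skip * sample + c_out * model_output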
@@ -208,13 +252,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
need to scale the denoising model input depending on the current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The input sample tensor.
timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -274,8 +318,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
def _compute_karras_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
Args:
ramp (`torch.Tensor`):
A tensor of values in [0, 1] representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed Karras sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -286,10 +349,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion.
def _compute_exponential_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
Args:
ramp (`torch.Tensor`):
A tensor of values representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed exponential sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -433,7 +513,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -684,7 +767,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.config.algorithm_type == "sde-dpmsolver++":
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
else:
noise = None
@@ -757,7 +843,18 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return noisy_samples
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
def _get_conditioning_c_in(self, sigma):
def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
"""
Compute the input conditioning factor for the EDM formulation.
Args:
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`float` or `torch.Tensor`:
The input conditioning factor `c_in`.
"""
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in


@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import torch
@@ -57,29 +57,28 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
sigma_min (`float`, *optional*, defaults to 0.002):
sigma_min (`float`, *optional*, defaults to `0.002`):
Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
range is [0, 10].
sigma_max (`float`, *optional*, defaults to 80.0):
sigma_max (`float`, *optional*, defaults to `80.0`):
Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
range is [0.2, 80.0].
sigma_data (`float`, *optional*, defaults to 0.5):
sigma_data (`float`, *optional*, defaults to `0.5`):
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
sigma_schedule (`str`, *optional*, defaults to `karras`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
(https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
sigma_schedule (`Literal["karras", "exponential"]`, *optional*, defaults to `"karras"`):
Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
(https://huggingface.co/papers/2206.00364). The `"exponential"` schedule was incorporated in this model:
https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, *optional*, defaults to `1000`):
The number of diffusion steps to train the model.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
rho (`float`, *optional*, defaults to 7.0):
prediction_type (`Literal["epsilon", "v_prediction"]`, *optional*, defaults to `"epsilon"`):
Prediction type of the scheduler function. `"epsilon"` predicts the noise of the diffusion process, and
`"v_prediction"` predicts the velocity (see section 2.4 of the [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
rho (`float`, *optional*, defaults to `7.0`):
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
final_sigmas_type (`str`, defaults to `"zero"`):
final_sigmas_type (`Literal["zero", "sigma_min"]`, *optional*, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
"""
_compatibles = []
@@ -91,12 +90,12 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
sigma_min: float = 0.002,
sigma_max: float = 80.0,
sigma_data: float = 0.5,
sigma_schedule: str = "karras",
sigma_schedule: Literal["karras", "exponential"] = "karras",
num_train_timesteps: int = 1000,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
rho: float = 7.0,
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
):
final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
) -> None:
if sigma_schedule not in ["karras", "exponential"]:
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
@@ -131,26 +130,41 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
def init_noise_sigma(self) -> float:
"""
Return the standard deviation of the initial noise distribution.
Returns:
`float`:
The initial noise sigma value computed as `(sigma_max**2 + 1) ** 0.5`.
"""
return (self.config.sigma_max**2 + 1) ** 0.5
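Pipelines typically use this property to scale the initial Gaussian latents before the first denoising step; a short usage sketch:

import torch
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler()  # sigma_max defaults to 80.0
latents = torch.randn(1, 4, 64, 64)
# scale initial noise to the expected std, (sigma_max**2 + 1) ** 0.5
latents = latents * scheduler.init_noise_sigma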
@property
def step_index(self):
def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
Return the index counter for the current timestep. The index will increase by 1 after each scheduler step.
Returns:
`int` or `None`:
The current step index, or `None` if not yet initialized.
"""
return self._step_index
@property
def begin_index(self):
def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
Return the index for the first timestep. This should be set from the pipeline with the `set_begin_index`
method.
Returns:
`int` or `None`:
The begin index, or `None` if not yet set.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -160,12 +174,36 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def precondition_inputs(self, sample, sigma):
def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the input sample by scaling it according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor to precondition.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The scaled input sample.
"""
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
def precondition_noise(self, sigma):
def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the noise level by applying a logarithmic transformation.
Args:
sigma (`float` or `torch.Tensor`):
The sigma (noise level) value to precondition.
Returns:
`torch.Tensor`:
The preconditioned noise value computed as `0.25 * log(sigma)`.
"""
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
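The transformation the docstring describes is a one-liner; for example:

import torch

sigma = torch.tensor([80.0, 1.0, 0.002])
c_noise = 0.25 * torch.log(sigma)  # noise-level embedding passed to the model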
@@ -173,7 +211,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
return c_noise
def precondition_outputs(self, sample, model_output, sigma):
def precondition_outputs(
self,
sample: torch.Tensor,
model_output: torch.Tensor,
sigma: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
Precondition the model outputs according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor.
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The denoised sample computed by combining the skip connection and output scaling.
"""
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
@@ -190,13 +248,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
need to scale the denoising model input depending on the current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The input sample tensor.
timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -214,19 +272,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
):
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
sigmas (`torch.Tensor` or `List[float]`, *optional*):
Custom sigmas to use for the denoising process. If not defined, the default behavior when
`num_inference_steps` is passed will be used.
"""
@@ -262,8 +320,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
def _compute_karras_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
Args:
ramp (`torch.Tensor`):
A tensor of values in [0, 1] representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed Karras sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -273,10 +350,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
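The rho-interpolation above is easy to reproduce standalone; a small sketch using the EDM defaults:

import torch

sigma_min, sigma_max, rho = 0.002, 80.0, 7.0
ramp = torch.linspace(0, 1, 25)  # interpolation positions in [0, 1]

min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
# sigmas[0] == sigma_max, sigmas[-1] == sigma_min, spaced densely near sigma_min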
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion.
def _compute_exponential_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
Args:
ramp (`torch.Tensor`):
A tensor of values representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed exponential sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
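Per the linked k-diffusion reference, the exponential schedule is a log-linear interpolation between the two extremes; a minimal sketch:

import math
import torch

sigma_min, sigma_max, n = 0.002, 80.0, 25
# log-linearly spaced noise levels from sigma_max down to sigma_min
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n).exp()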
@@ -342,32 +436,38 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
pred_original_sample: Optional[torch.Tensor] = None,
) -> Union[EDMEulerSchedulerOutput, Tuple]:
) -> Union[EDMEulerSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The direct output from the learned diffusion model.
timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
s_churn (`float`, *optional*, defaults to `0.0`):
The amount of stochasticity to add at each step. Higher values add more noise.
s_tmin (`float`, *optional*, defaults to `0.0`):
The minimum sigma threshold below which no noise is added.
s_tmax (`float`, *optional*, defaults to `float("inf")`):
The maximum sigma threshold above which no noise is added.
s_noise (`float`, *optional*, defaults to `1.0`):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple.
A random number generator for reproducibility.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or tuple.
pred_original_sample (`torch.Tensor`, *optional*):
The predicted denoised sample from a previous step. If provided, skips recomputation.
Returns:
[`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
[`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or `tuple`:
If `return_dict` is `True`, an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the previous sample tensor and the
second element is the predicted original sample tensor.
"""
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
@@ -399,7 +499,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
if gamma > 0:
noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
model_output.shape,
dtype=model_output.dtype,
device=model_output.device,
generator=generator,
)
eps = noise * s_noise
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
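Assembling the pieces visible in this hunk, one stochastic Euler step in the spirit of Karras et al. (2022), Algorithm 2, might look like the following standalone sketch (the helper name and defaults are illustrative, not the scheduler's API; sigma inputs are assumed to be Python floats):

import torch

def euler_step(sample, denoised, sigma, sigma_next, num_sigmas,
               s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0,
               generator=None):
    # optionally inflate the noise level ("churn") before stepping
    gamma = min(s_churn / (num_sigmas - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
    sigma_hat = sigma * (gamma + 1)
    if gamma > 0:
        eps = torch.randn(sample.shape, generator=generator, dtype=sample.dtype) * s_noise
        sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
    # first-order Euler update toward the denoised prediction
    derivative = (sample - denoised) / sigma_hat
    return sample + derivative * (sigma_next - sigma_hat)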
@@ -478,9 +581,20 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma
return noisy_samples
def _get_conditioning_c_in(self, sigma):
def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
"""
Compute the input conditioning factor for the EDM formulation.
Args:
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`float` or `torch.Tensor`:
The input conditioning factor `c_in`.
"""
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps