Improve docstrings and type hints in scheduling_edm_euler.py (#12871)
* docs: add comprehensive docstrings and refine type hints for EDM scheduler methods and config parameters.
* refactor: Add type hints to DPM-Solver scheduler methods.
@@ -143,7 +143,20 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self._begin_index = begin_index

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
-    def precondition_inputs(self, sample, sigma):
+    def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+        """
+        Precondition the input sample by scaling it according to the EDM formulation.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample tensor to precondition.
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `torch.Tensor`:
+                The scaled input sample.
+        """
         c_in = self._get_conditioning_c_in(sigma)
         scaled_sample = sample * c_in
         return scaled_sample
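For intuition, here is a minimal standalone sketch of what this preconditioning computes, assuming the EDM default `sigma_data = 0.5`; the factor `c_in` keeps the effective input variance roughly constant across noise levels:

```python
import torch

sigma_data = 0.5  # EDM default from the scheduler config
sample = torch.randn(1, 4, 8, 8)

for sigma in (0.002, 1.0, 80.0):
    # c_in = 1 / sqrt(sigma**2 + sigma_data**2), as in _get_conditioning_c_in
    c_in = 1 / ((sigma**2 + sigma_data**2) ** 0.5)
    print(f"sigma={sigma:>6}: c_in={c_in:.4f}, scaled std={(sample * c_in).std():.4f}")
```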
@@ -155,7 +168,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return sigma.atan() / math.pi * 2

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
-    def precondition_outputs(self, sample, model_output, sigma):
+    def precondition_outputs(
+        self,
+        sample: torch.Tensor,
+        model_output: torch.Tensor,
+        sigma: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Precondition the model outputs according to the EDM formulation.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample tensor.
+            model_output (`torch.Tensor`):
+                The direct output from the learned diffusion model.
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `torch.Tensor`:
+                The denoised sample computed by combining the skip connection and output scaling.
+        """
         sigma_data = self.config.sigma_data
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)

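The hunk cuts off at `c_skip`; for context, a sketch of the full EDM output combination. The `c_out` expression below follows the paper's `epsilon` parameterization and is an assumption here, since it lies outside the hunk:

```python
import torch

def edm_denoise(sample: torch.Tensor, model_output: torch.Tensor,
                sigma: float, sigma_data: float = 0.5) -> torch.Tensor:
    # Skip-connection weight: near 1 at low noise, near 0 at high noise.
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    # Output scale for the `epsilon` parameterization (assumed, not shown above).
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
    return c_skip * sample + c_out * model_output

denoised = edm_denoise(torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8), sigma=1.0)
```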
@@ -173,13 +206,13 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
-        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
-        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+        Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
+        need to scale the denoising model input depending on the current timestep.

         Args:
             sample (`torch.Tensor`):
-                The input sample.
-            timestep (`int`, *optional*):
+                The input sample tensor.
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.

         Returns:
@@ -242,8 +275,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.noise_sampler = None

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
-    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _compute_karras_sigmas(
+        self,
+        ramp: torch.Tensor,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ) -> torch.Tensor:
+        """
+        Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            ramp (`torch.Tensor`):
+                A tensor of values in [0, 1] representing the interpolation positions.
+            sigma_min (`float`, *optional*):
+                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+            sigma_max (`float`, *optional*):
+                Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+        Returns:
+            `torch.Tensor`:
+                The computed Karras sigma schedule.
+        """
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max

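A standalone sketch of the schedule this method produces, using the documented EDM defaults (`sigma_min=0.002`, `sigma_max=80.0`, `rho=7.0`); the interpolation formula matches the context lines visible in the exponential hunk further down:

```python
import torch

sigma_min, sigma_max, rho = 0.002, 80.0, 7.0
ramp = torch.linspace(0, 1, 10)  # interpolation positions in [0, 1]

min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
# Interpolate in sigma**(1/rho) space, then raise back to sigma space;
# this clusters steps near sigma_min, where fine detail is resolved.
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas)  # monotonically decreasing from 80.0 to 0.002
```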
@@ -254,10 +306,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return sigmas

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
-    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
-        """Implementation closely follows k-diffusion.
-
+    def _compute_exponential_sigmas(
+        self,
+        ramp: torch.Tensor,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ) -> torch.Tensor:
+        """
+        Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
         https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+
+        Args:
+            ramp (`torch.Tensor`):
+                A tensor of values representing the interpolation positions.
+            sigma_min (`float`, *optional*):
+                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+            sigma_max (`float`, *optional*):
+                Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+        Returns:
+            `torch.Tensor`:
+                The computed exponential sigma schedule.
+        """
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max
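For comparison, a sketch of the exponential variant linked above: sigmas spaced uniformly in log-space. (The `.flip(0)` that makes the schedule decreasing is an assumption here; the method body is outside the hunk.)

```python
import math

import torch

sigma_min, sigma_max, n = 0.002, 80.0, 10
# Uniform spacing in log(sigma), then exponentiate back.
sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), n).exp().flip(0)
print(sigmas)  # decreasing from 80.0 to 0.002
```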
@@ -354,7 +423,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
-        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+        sigma_t, sigma_s = (
+            self.sigmas[self.step_index + 1],
+            self.sigmas[self.step_index],
+        )
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
         alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
         lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -540,7 +612,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
                 [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
             )
             self.noise_sampler = BrownianTreeNoiseSampler(
-                model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
+                model_output,
+                sigma_min=self.config.sigma_min,
+                sigma_max=self.config.sigma_max,
+                seed=seed,
             )
         noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
             model_output.device
@@ -612,7 +687,18 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return noisy_samples

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
-    def _get_conditioning_c_in(self, sigma):
+    def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+        """
+        Compute the input conditioning factor for the EDM formulation.
+
+        Args:
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `float` or `torch.Tensor`:
+                The input conditioning factor `c_in`.
+        """
         c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
         return c_in

@@ -175,13 +175,37 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self._begin_index = begin_index

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
-    def precondition_inputs(self, sample, sigma):
+    def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+        """
+        Precondition the input sample by scaling it according to the EDM formulation.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample tensor to precondition.
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `torch.Tensor`:
+                The scaled input sample.
+        """
         c_in = self._get_conditioning_c_in(sigma)
         scaled_sample = sample * c_in
         return scaled_sample

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise
-    def precondition_noise(self, sigma):
+    def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+        """
+        Precondition the noise level by applying a logarithmic transformation.
+
+        Args:
+            sigma (`float` or `torch.Tensor`):
+                The sigma (noise level) value to precondition.
+
+        Returns:
+            `torch.Tensor`:
+                The preconditioned noise value computed as `0.25 * log(sigma)`.
+        """
         if not isinstance(sigma, torch.Tensor):
             sigma = torch.tensor([sigma])

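A quick sketch of the transformation documented above: the denoiser is conditioned on `0.25 * ln(sigma)` rather than on sigma directly, which compresses the wide [0.002, 80] range into a small, roughly symmetric interval:

```python
import torch

def precondition_noise(sigma) -> torch.Tensor:
    # Accept plain floats as well as tensors, mirroring the method above.
    if not isinstance(sigma, torch.Tensor):
        sigma = torch.tensor([sigma])
    return 0.25 * sigma.log()

print(precondition_noise(80.0))   # tensor([1.0955])
print(precondition_noise(0.002))  # tensor([-1.5537])
```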
@@ -190,7 +214,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return c_noise

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
-    def precondition_outputs(self, sample, model_output, sigma):
+    def precondition_outputs(
+        self,
+        sample: torch.Tensor,
+        model_output: torch.Tensor,
+        sigma: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Precondition the model outputs according to the EDM formulation.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample tensor.
+            model_output (`torch.Tensor`):
+                The direct output from the learned diffusion model.
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `torch.Tensor`:
+                The denoised sample computed by combining the skip connection and output scaling.
+        """
         sigma_data = self.config.sigma_data
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)

@@ -208,13 +252,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
-        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
-        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+        Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
+        need to scale the denoising model input depending on the current timestep.

         Args:
             sample (`torch.Tensor`):
-                The input sample.
-            timestep (`int`, *optional*):
+                The input sample tensor.
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.

         Returns:
@@ -274,8 +318,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
-    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _compute_karras_sigmas(
+        self,
+        ramp: torch.Tensor,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ) -> torch.Tensor:
+        """
+        Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            ramp (`torch.Tensor`):
+                A tensor of values in [0, 1] representing the interpolation positions.
+            sigma_min (`float`, *optional*):
+                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+            sigma_max (`float`, *optional*):
+                Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+        Returns:
+            `torch.Tensor`:
+                The computed Karras sigma schedule.
+        """
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max

@@ -286,10 +349,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return sigmas

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
-    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
-        """Implementation closely follows k-diffusion.
-
+    def _compute_exponential_sigmas(
+        self,
+        ramp: torch.Tensor,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ) -> torch.Tensor:
+        """
+        Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
         https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+
+        Args:
+            ramp (`torch.Tensor`):
+                A tensor of values representing the interpolation positions.
+            sigma_min (`float`, *optional*):
+                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+            sigma_max (`float`, *optional*):
+                Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+        Returns:
+            `torch.Tensor`:
+                The computed exponential sigma schedule.
+        """
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max
@@ -433,7 +513,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             `torch.Tensor`:
                 The sample tensor at the previous timestep.
         """
-        sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
+        sigma_t, sigma_s = (
+            self.sigmas[self.step_index + 1],
+            self.sigmas[self.step_index],
+        )
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
         alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
         lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -684,7 +767,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

         if self.config.algorithm_type == "sde-dpmsolver++":
             noise = randn_tensor(
-                model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+                model_output.shape,
+                generator=generator,
+                device=model_output.device,
+                dtype=model_output.dtype,
             )
         else:
             noise = None
@@ -757,7 +843,18 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return noisy_samples

     # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
-    def _get_conditioning_c_in(self, sigma):
+    def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+        """
+        Compute the input conditioning factor for the EDM formulation.
+
+        Args:
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `float` or `torch.Tensor`:
+                The input conditioning factor `c_in`.
+        """
         c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
         return c_in

@@ -14,7 +14,7 @@

 import math
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import List, Literal, Optional, Tuple, Union

 import torch

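The new `Literal` import is what lets the config hints below enumerate their accepted strings, so a static type checker can flag an invalid value before the runtime `ValueError` does. A minimal illustration:

```python
from typing import Literal

def set_schedule(sigma_schedule: Literal["karras", "exponential"] = "karras") -> str:
    return sigma_schedule

set_schedule("exponential")  # fine
set_schedule("linear")       # rejected by mypy/pyright, though it still runs
```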
@@ -57,29 +57,28 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.

     Args:
-        sigma_min (`float`, *optional*, defaults to 0.002):
+        sigma_min (`float`, *optional*, defaults to `0.002`):
             Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
             range is [0, 10].
-        sigma_max (`float`, *optional*, defaults to 80.0):
+        sigma_max (`float`, *optional*, defaults to `80.0`):
             Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
             range is [0.2, 80.0].
-        sigma_data (`float`, *optional*, defaults to 0.5):
+        sigma_data (`float`, *optional*, defaults to `0.5`):
             The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
-        sigma_schedule (`str`, *optional*, defaults to `karras`):
-            Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
-            (https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
-            schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
-        num_train_timesteps (`int`, defaults to 1000):
+        sigma_schedule (`Literal["karras", "exponential"]`, *optional*, defaults to `"karras"`):
+            Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
+            (https://huggingface.co/papers/2206.00364). The `"exponential"` schedule was incorporated in this model:
+            https://huggingface.co/stabilityai/cosxl.
+        num_train_timesteps (`int`, *optional*, defaults to `1000`):
             The number of diffusion steps to train the model.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
-            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
-            Video](https://huggingface.co/papers/2210.02303) paper).
-        rho (`float`, *optional*, defaults to 7.0):
+        prediction_type (`Literal["epsilon", "v_prediction"]`, *optional*, defaults to `"epsilon"`):
+            Prediction type of the scheduler function. `"epsilon"` predicts the noise of the diffusion process, and
+            `"v_prediction"` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
+        rho (`float`, *optional*, defaults to `7.0`):
             The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
-        final_sigmas_type (`str`, defaults to `"zero"`):
+        final_sigmas_type (`Literal["zero", "sigma_min"]`, *optional*, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
-            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
     """

     _compatibles = []
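As a usage sketch of the parameters documented above (all values are the documented defaults, except `sigma_schedule`, which opts into the exponential schedule used by the CosXL checkpoint):

```python
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler(
    sigma_min=0.002,
    sigma_max=80.0,
    sigma_data=0.5,
    sigma_schedule="exponential",
    num_train_timesteps=1000,
    prediction_type="epsilon",
    rho=7.0,
    final_sigmas_type="zero",
)
```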
@@ -91,12 +90,12 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         sigma_min: float = 0.002,
         sigma_max: float = 80.0,
         sigma_data: float = 0.5,
-        sigma_schedule: str = "karras",
+        sigma_schedule: Literal["karras", "exponential"] = "karras",
         num_train_timesteps: int = 1000,
-        prediction_type: str = "epsilon",
+        prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
         rho: float = 7.0,
-        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
-    ):
+        final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
+    ) -> None:
         if sigma_schedule not in ["karras", "exponential"]:
             raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")

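The `Literal` hints document the same constraint the constructor already enforces at runtime, as the context lines above show:

```python
from diffusers import EDMEulerScheduler

try:
    EDMEulerScheduler(sigma_schedule="linear")
except ValueError as err:
    print(err)  # only "karras" and "exponential" are accepted
```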
@@ -131,26 +130,41 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
-    def init_noise_sigma(self):
-        # standard deviation of the initial noise distribution
+    def init_noise_sigma(self) -> float:
+        """
+        Return the standard deviation of the initial noise distribution.
+
+        Returns:
+            `float`:
+                The initial noise sigma value computed as `(sigma_max**2 + 1) ** 0.5`.
+        """
         return (self.config.sigma_max**2 + 1) ** 0.5

     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
-        The index counter for current timestep. It will increase 1 after each scheduler step.
+        Return the index counter for the current timestep. The index will increase by 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index, or `None` if not yet initialized.
         """
         return self._step_index

     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
-        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        Return the index for the first timestep. This should be set from the pipeline with the `set_begin_index`
+        method.
+
+        Returns:
+            `int` or `None`:
+                The begin index, or `None` if not yet set.
         """
         return self._begin_index

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

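A sketch of how these properties behave on a fresh scheduler; `init_noise_sigma` is what pipelines use to scale the initial Gaussian latents:

```python
import torch
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler()
print(scheduler.init_noise_sigma)  # (80.0**2 + 1) ** 0.5 ≈ 80.0062
print(scheduler.step_index)        # None until stepping begins

# Pipelines typically scale the initial noise by this sigma:
latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
```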
@@ -160,12 +174,36 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index

-    def precondition_inputs(self, sample, sigma):
+    def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+        """
+        Precondition the input sample by scaling it according to the EDM formulation.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample tensor to precondition.
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `torch.Tensor`:
+                The scaled input sample.
+        """
         c_in = self._get_conditioning_c_in(sigma)
         scaled_sample = sample * c_in
         return scaled_sample

-    def precondition_noise(self, sigma):
+    def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
+        """
+        Precondition the noise level by applying a logarithmic transformation.
+
+        Args:
+            sigma (`float` or `torch.Tensor`):
+                The sigma (noise level) value to precondition.
+
+        Returns:
+            `torch.Tensor`:
+                The preconditioned noise value computed as `0.25 * log(sigma)`.
+        """
         if not isinstance(sigma, torch.Tensor):
             sigma = torch.tensor([sigma])

@@ -173,7 +211,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):

         return c_noise

-    def precondition_outputs(self, sample, model_output, sigma):
+    def precondition_outputs(
+        self,
+        sample: torch.Tensor,
+        model_output: torch.Tensor,
+        sigma: Union[float, torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Precondition the model outputs according to the EDM formulation.
+
+        Args:
+            sample (`torch.Tensor`):
+                The input sample tensor.
+            model_output (`torch.Tensor`):
+                The direct output from the learned diffusion model.
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `torch.Tensor`:
+                The denoised sample computed by combining the skip connection and output scaling.
+        """
         sigma_data = self.config.sigma_data
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)

@@ -190,13 +248,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):

     def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
-        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
-        current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+        Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
+        need to scale the denoising model input depending on the current timestep.

         Args:
             sample (`torch.Tensor`):
-                The input sample.
-            timestep (`int`, *optional*):
+                The input sample tensor.
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.

         Returns:
@@ -214,19 +272,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):

     def set_timesteps(
         self,
-        num_inference_steps: int = None,
-        device: Union[str, torch.device] = None,
+        num_inference_steps: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = None,
         sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
-    ):
+    ) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
-            sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
+            sigmas (`torch.Tensor` or `List[float]`, *optional*):
                 Custom sigmas to use for the denoising process. If not defined, the default behavior when
                 `num_inference_steps` is passed will be used.
         """
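A sketch of the two calling conventions the updated signature documents, by step count or with custom sigmas:

```python
import torch
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler()

# Option 1: let the scheduler derive the sigma schedule itself.
scheduler.set_timesteps(num_inference_steps=30, device="cpu")

# Option 2: pass a custom (decreasing) sigma schedule instead.
custom_sigmas = torch.linspace(80.0, 0.002, steps=30).tolist()
scheduler.set_timesteps(sigmas=custom_sigmas, device="cpu")
print(scheduler.timesteps.shape)
```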
@@ -262,8 +320,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
-    def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _compute_karras_sigmas(
+        self,
+        ramp: torch.Tensor,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ) -> torch.Tensor:
+        """
+        Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            ramp (`torch.Tensor`):
+                A tensor of values in [0, 1] representing the interpolation positions.
+            sigma_min (`float`, *optional*):
+                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+            sigma_max (`float`, *optional*):
+                Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+        Returns:
+            `torch.Tensor`:
+                The computed Karras sigma schedule.
+        """
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max

@@ -273,10 +350,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas

-    def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
-        """Implementation closely follows k-diffusion.
-
+    def _compute_exponential_sigmas(
+        self,
+        ramp: torch.Tensor,
+        sigma_min: Optional[float] = None,
+        sigma_max: Optional[float] = None,
+    ) -> torch.Tensor:
+        """
+        Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
         https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
+
+        Args:
+            ramp (`torch.Tensor`):
+                A tensor of values representing the interpolation positions.
+            sigma_min (`float`, *optional*):
+                Minimum sigma value. If `None`, uses `self.config.sigma_min`.
+            sigma_max (`float`, *optional*):
+                Maximum sigma value. If `None`, uses `self.config.sigma_max`.
+
+        Returns:
+            `torch.Tensor`:
+                The computed exponential sigma schedule.
+        """
         sigma_min = sigma_min or self.config.sigma_min
         sigma_max = sigma_max or self.config.sigma_max
@@ -342,32 +436,38 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
         pred_original_sample: Optional[torch.Tensor] = None,
-    ) -> Union[EDMEulerSchedulerOutput, Tuple]:
+    ) -> Union[EDMEulerSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         process from the learned model outputs (most often the predicted noise).

         Args:
             model_output (`torch.Tensor`):
-                The direct output from learned diffusion model.
-            timestep (`float`):
+                The direct output from the learned diffusion model.
+            timestep (`float` or `torch.Tensor`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            s_churn (`float`):
-            s_tmin (`float`):
-            s_tmax (`float`):
-            s_noise (`float`, defaults to 1.0):
+            s_churn (`float`, *optional*, defaults to `0.0`):
+                The amount of stochasticity to add at each step. Higher values add more noise.
+            s_tmin (`float`, *optional*, defaults to `0.0`):
+                The minimum sigma threshold below which no noise is added.
+            s_tmax (`float`, *optional*, defaults to `float("inf")`):
+                The maximum sigma threshold above which no noise is added.
+            s_noise (`float`, *optional*, defaults to `1.0`):
+                Scaling factor for noise added to the sample.
             generator (`torch.Generator`, *optional*):
-                A random number generator.
-            return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple.
+                A random number generator for reproducibility.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or tuple.
+            pred_original_sample (`torch.Tensor`, *optional*):
+                The predicted denoised sample from a previous step. If provided, skips recomputation.

         Returns:
-            [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is
-                returned, otherwise a tuple is returned where the first element is the sample tensor.
+            [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or `tuple`:
+                If `return_dict` is `True`, an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] is
+                returned, otherwise a tuple is returned where the first element is the previous sample tensor and the
+                second element is the predicted original sample tensor.
         """

         if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
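Putting the documented arguments together, a sketch of a stochastic Euler sampling loop; `unet` is a hypothetical sigma-conditioned denoiser standing in for a real model, not part of this diff:

```python
import torch
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler()
scheduler.set_timesteps(num_inference_steps=25)

sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
generator = torch.Generator().manual_seed(0)

for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    model_output = unet(model_input, t).sample  # hypothetical denoiser call
    # s_churn > 0 re-injects noise while sigma is inside [s_tmin, s_tmax].
    sample = scheduler.step(
        model_output, t, sample,
        s_churn=0.5, s_tmin=0.05, s_tmax=50.0, s_noise=1.0,
        generator=generator,
    ).prev_sample
```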
@@ -399,7 +499,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):

         if gamma > 0:
             noise = randn_tensor(
-                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+                model_output.shape,
+                dtype=model_output.dtype,
+                device=model_output.device,
+                generator=generator,
             )
             eps = noise * s_noise
             sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
@@ -478,9 +581,20 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         return noisy_samples

-    def _get_conditioning_c_in(self, sigma):
+    def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
+        """
+        Compute the input conditioning factor for the EDM formulation.
+
+        Args:
+            sigma (`float` or `torch.Tensor`):
+                The current sigma (noise level) value.
+
+        Returns:
+            `float` or `torch.Tensor`:
+                The input conditioning factor `c_in`.
+        """
         c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
         return c_in

-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps