
EDMEulerScheduler accept sigmas, add final_sigmas_type (#10734)

Author: hlky
Date: 2025-02-07 06:53:52 +00:00 (committed by GitHub)
Parent: d43ce14e2d
Commit: 464374fb87

src/diffusers/schedulers/scheduling_edm_euler.py

@@ -14,7 +14,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
             Video](https://imagen.research.google/video/paper.pdf) paper).
         rho (`float`, *optional*, defaults to 7.0):
             The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
+        final_sigmas_type (`str`, defaults to `"zero"`):
+            The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
     """

     _compatibles = []
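
For reference, the Karras schedule this docstring refers to maps a ramp in [0, 1] onto sigmas descending from `sigma_max` to `sigma_min`, so `final_sigmas_type` only decides the one extra value appended after the last schedule entry. A minimal standalone sketch of that tail behavior, assuming the EDM defaults `sigma_min=0.002`, `sigma_max=80.0`, `rho=7.0` (not shown in this hunk):

import torch

# Assumed EDM defaults; these are configurable on the scheduler itself
sigma_min, sigma_max, rho = 0.002, 80.0, 7.0
final_sigmas_type = "zero"  # or "sigma_min"

# Karras schedule: ramp position 0 -> sigma_max, position 1 -> sigma_min
ramp = torch.linspace(0, 1, 10)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# The new option only selects the appended terminal value
sigma_last = sigmas[-1] if final_sigmas_type == "sigma_min" else 0.0
sigmas = torch.cat([sigmas, torch.full((1,), float(sigma_last))])
print(sigmas[-2:])  # tensor([0.0020, 0.0000]) with "zero"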
@@ -92,6 +95,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps: int = 1000,
         prediction_type: str = "epsilon",
         rho: float = 7.0,
+        final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
         if sigma_schedule not in ["karras", "exponential"]:
             raise ValueError(f"Wrong value provided for `{sigma_schedule=}`.")
@@ -99,15 +103,24 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         # setable values
         self.num_inference_steps = None
-        ramp = torch.linspace(0, 1, num_train_timesteps)
+        sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
         if sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)

         self.timesteps = self.precondition_noise(sigmas)

-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero' or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])

         self.is_scale_input_called = False
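
Two details in this hunk are easy to miss: the training ramp is now `arange(num_train_timesteps + 1) / num_train_timesteps`, which has one more entry than the old `linspace(0, 1, num_train_timesteps)`, and the terminal sigma is no longer hard-coded to zero. A hedged sketch of the resulting behavior, assuming a diffusers build that includes this commit:

from diffusers import EDMEulerScheduler

# Default: terminal sigma stays 0, matching the old behavior
sched_zero = EDMEulerScheduler(final_sigmas_type="zero")
print(sched_zero.sigmas[-1])  # tensor(0.)

# New option: terminal sigma repeats the last training-schedule value
sched_min = EDMEulerScheduler(final_sigmas_type="sigma_min")
print(sched_min.sigmas[-1])   # ~0.002 (sigma_min) with the default Karras schedule

# Any other value fails fast in __init__:
#   EDMEulerScheduler(final_sigmas_type="sigma_max")  -> ValueError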
@@ -197,7 +210,12 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = True
         return sample

-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(
+        self,
+        num_inference_steps: int = None,
+        device: Union[str, torch.device] = None,
+        sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
+    ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -206,19 +224,36 @@
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+            sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
+                Custom sigmas to use for the denoising process. If not defined, the default behavior when
+                `num_inference_steps` is passed will be used.
         """
         self.num_inference_steps = num_inference_steps
-        ramp = torch.linspace(0, 1, self.num_inference_steps)
+        if sigmas is None:
+            sigmas = torch.linspace(0, 1, self.num_inference_steps)
+        elif isinstance(sigmas, float):
+            sigmas = torch.tensor(sigmas, dtype=torch.float32)
+        else:
+            sigmas = sigmas
         if self.config.sigma_schedule == "karras":
-            sigmas = self._compute_karras_sigmas(ramp)
+            sigmas = self._compute_karras_sigmas(sigmas)
         elif self.config.sigma_schedule == "exponential":
-            sigmas = self._compute_exponential_sigmas(ramp)
+            sigmas = self._compute_exponential_sigmas(sigmas)
         sigmas = sigmas.to(dtype=torch.float32, device=device)

         self.timesteps = self.precondition_noise(sigmas)

-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if self.config.final_sigmas_type == "sigma_min":
+            sigma_last = sigmas[-1]
+        elif self.config.final_sigmas_type == "zero":
+            sigma_last = 0
+        else:
+            raise ValueError(
+                f"`final_sigmas_type` must be one of 'zero' or 'sigma_min', but got {self.config.final_sigmas_type}"
+            )
+
+        self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])

         self._step_index = None
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
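
One subtlety for callers: a custom `sigmas` argument is still passed through `_compute_karras_sigmas` / `_compute_exponential_sigmas`, so it is interpreted as ramp positions in [0, 1] rather than as raw sigma values. A usage sketch, again assuming a diffusers build with this commit:

import torch
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler(final_sigmas_type="sigma_min")

# Default path: the ramp is derived from num_inference_steps
scheduler.set_timesteps(num_inference_steps=25)

# Custom path: hand-picked ramp positions; squaring clusters them near 0,
# i.e. toward the high-noise (sigma_max) end of the schedule
ramp = torch.linspace(0, 1, 25) ** 2
scheduler.set_timesteps(sigmas=ramp)

print(len(scheduler.timesteps))  # 25
print(scheduler.sigmas[-1])      # sigma_min rather than 0, per final_sigmas_type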