mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
EDMEulerScheduler accept sigmas, add final_sigmas_type (#10734)
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
Video](https://imagen.research.google/video/paper.pdf) paper).
|
||||
rho (`float`, *optional*, defaults to 7.0):
|
||||
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
|
||||
final_sigmas_type (`str`, defaults to `"zero"`):
|
||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
||||
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
@@ -92,6 +95,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
prediction_type: str = "epsilon",
|
||||
rho: float = 7.0,
|
||||
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
|
||||
):
|
||||
if sigma_schedule not in ["karras", "exponential"]:
|
||||
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
|
||||
@@ -99,15 +103,24 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
|
||||
ramp = torch.linspace(0, 1, num_train_timesteps)
|
||||
sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
|
||||
if sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(ramp)
|
||||
sigmas = self._compute_karras_sigmas(sigmas)
|
||||
elif sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(ramp)
|
||||
sigmas = self._compute_exponential_sigmas(sigmas)
|
||||
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = sigmas[-1]
|
||||
elif self.config.final_sigmas_type == "zero":
|
||||
sigma_last = 0
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
||||
)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
|
||||
|
||||
self.is_scale_input_called = False
|
||||
|
||||
@@ -197,7 +210,12 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.is_scale_input_called = True
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -206,19 +224,36 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
|
||||
Custom sigmas to use for the denoising process. If not defined, the default behavior when
|
||||
`num_inference_steps` is passed will be used.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
ramp = torch.linspace(0, 1, self.num_inference_steps)
|
||||
if sigmas is None:
|
||||
sigmas = torch.linspace(0, 1, self.num_inference_steps)
|
||||
elif isinstance(sigmas, float):
|
||||
sigmas = torch.tensor(sigmas, dtype=torch.float32)
|
||||
else:
|
||||
sigmas = sigmas
|
||||
if self.config.sigma_schedule == "karras":
|
||||
sigmas = self._compute_karras_sigmas(ramp)
|
||||
sigmas = self._compute_karras_sigmas(sigmas)
|
||||
elif self.config.sigma_schedule == "exponential":
|
||||
sigmas = self._compute_exponential_sigmas(ramp)
|
||||
sigmas = self._compute_exponential_sigmas(sigmas)
|
||||
|
||||
sigmas = sigmas.to(dtype=torch.float32, device=device)
|
||||
self.timesteps = self.precondition_noise(sigmas)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = sigmas[-1]
|
||||
elif self.config.final_sigmas_type == "zero":
|
||||
sigma_last = 0
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
|
||||
)
|
||||
|
||||
self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
Reference in New Issue
Block a user