
Add exponential sigmas to other schedulers and update docs (#9518)

Authored by hlky on 2024-09-25 01:50:12 +01:00, committed by GitHub
parent bac8a2412d · commit b52684c3ed
13 changed files with 345 additions and 9 deletions

docs/source/en/api/schedulers/overview.md

@@ -46,11 +46,12 @@ Many schedulers are implemented from the [k-diffusion](https://github.com/crowso
| N/A | [`UniPCMultistepScheduler`] | |
## Noise schedules and schedule types
| A1111/k-diffusion        | 🤗 Diffusers                                                                 |
|--------------------------|----------------------------------------------------------------------------|
| Karras                   | init with `use_karras_sigmas=True`                                         |
| sgm_uniform              | init with `timestep_spacing="trailing"`                                    |
| simple                   | init with `timestep_spacing="trailing"`                                    |
| exponential              | init with `timestep_spacing="linspace"`, `use_exponential_sigmas=True`    |
All schedulers are built from the base [`SchedulerMixin`] class which implements low level utilities shared by all schedulers.
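For example, here is a minimal sketch of selecting the new exponential schedule from the table above; the checkpoint name is illustrative only, and any pipeline whose scheduler accepts these config keys should behave the same way:

```python
from diffusers import DiffusionPipeline, EulerDiscreteScheduler

# Illustrative checkpoint; substitute any model with a compatible scheduler config.
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")

# "exponential" row of the table: linspace timestep spacing + exponential sigmas.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="linspace",
    use_exponential_sigmas=True,
)
```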

src/diffusers/schedulers/scheduling_deis_multistep.py

@@ -111,6 +111,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -138,9 +140,12 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
solver_type: str = "logrho",
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
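As a quick sketch of the new guard above (class and arguments exactly as in this hunk), requesting both sigma schedules now fails at construction time:

```python
from diffusers import DEISMultistepScheduler

# Asking for both schedules at once is rejected before any betas are built.
try:
    DEISMultistepScheduler(use_karras_sigmas=True, use_exponential_sigmas=True)
except ValueError as err:
    print(err)
    # Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.
```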
@@ -255,6 +260,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -366,6 +374,28 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,
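A standalone sketch of what `_convert_to_exponential` computes: sigmas evenly spaced in log-space, so consecutive sigmas differ by a constant ratio. The endpoint values below are made-up example numbers, not scheduler defaults:

```python
import math

import torch

sigma_max, sigma_min, steps = 14.6146, 0.0292, 5
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), steps).exp()
print(sigmas)                    # ~[14.6146, 3.0902, 0.6533, 0.1381, 0.0292]
print(sigmas[1:] / sigmas[:-1])  # constant step ratio, here ~0.2114 throughout
```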

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

@@ -161,6 +161,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
use_lu_lambdas (`bool`, *optional*, defaults to `False`):
Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
@@ -206,6 +208,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_lu_lambdas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
@@ -214,6 +217,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -330,6 +335,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
if timesteps is not None and self.config.use_lu_lambdas:
raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
if timesteps is not None and self.config.use_exponential_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
if timesteps is not None:
timesteps = np.array(timesteps).astype(np.int64)
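A sketch of the new incompatibility check: passing an explicit timestep list together with exponential sigmas fails fast. The timestep values here are arbitrary:

```python
from diffusers import DPMSolverMultistepScheduler

sched = DPMSolverMultistepScheduler(use_exponential_sigmas=True)
try:
    sched.set_timesteps(timesteps=[999, 749, 499, 249, 0])
except ValueError as err:
    print(err)  # Cannot set `timesteps` with `config.use_exponential_sigmas = True`.
```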
@@ -378,6 +385,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
lambdas = self._convert_to_lu(in_lambdas=lambdas, num_inference_steps=num_inference_steps)
sigmas = np.exp(lambdas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -510,6 +520,28 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
lambdas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return lambdas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,
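The `hasattr` fallback in `_convert_to_exponential` can be observed by calling the helper directly (it is private, so this is for illustration only): this scheduler's config defines neither `sigma_min` nor `sigma_max`, so the endpoints are taken from `in_sigmas` itself. The sigma ramp below is made up:

```python
import torch

from diffusers import DPMSolverMultistepScheduler

sched = DPMSolverMultistepScheduler()
# Descending ramp: in_sigmas[0] is treated as sigma_max, in_sigmas[-1] as sigma_min.
in_sigmas = torch.linspace(14.0, 0.03, 1000)
print(sched._convert_to_exponential(in_sigmas=in_sigmas, num_inference_steps=8))
```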

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

@@ -124,6 +124,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
@@ -158,11 +160,14 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
@@ -213,6 +218,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.use_karras_sigmas = use_karras_sigmas
self.use_exponential_sigmas = use_exponential_sigmas
@property
def step_index(self):
@@ -267,6 +273,9 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
timesteps = timesteps.copy().astype(np.int64)
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_max = (
@@ -385,6 +394,28 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output
def convert_model_output(
self,

src/diffusers/schedulers/scheduling_dpmsolver_sde.py

@@ -160,6 +160,8 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
noise_sampler_seed (`int`, *optional*, defaults to `None`):
The random seed to use for the noise sampler. If `None`, a random seed is generated.
timestep_spacing (`str`, defaults to `"linspace"`):
@@ -182,10 +184,13 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
noise_sampler_seed: Optional[int] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -341,6 +346,9 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)
@@ -421,6 +429,28 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
@property
def state_in_first_order(self):
return self.sample is None
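An end-to-end sketch for this scheduler, assuming the default beta schedule: the fixed `noise_sampler_seed` keeps the SDE noise reproducible, and the resulting sigmas follow the log-uniform envelope built above:

```python
from diffusers import DPMSolverSDEScheduler

sched = DPMSolverSDEScheduler(use_exponential_sigmas=True, noise_sampler_seed=0)
sched.set_timesteps(num_inference_steps=8)
print(sched.sigmas)  # sigma sequence as stored by the scheduler
```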

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

@@ -123,6 +123,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
@@ -154,10 +156,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
solver_type: str = "midpoint",
lower_order_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if algorithm_type == "dpmsolver":
deprecation_message = "algorithm_type `dpmsolver` is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate("algorithm_types=dpmsolver", "1.0.0", deprecation_message)
@@ -300,6 +305,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
if timesteps is not None and self.config.use_karras_sigmas:
raise ValueError("Cannot use `timesteps` when `config.use_karras_sigmas=True`.")
if timesteps is not None and self.config.use_exponential_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
num_inference_steps = num_inference_steps or len(timesteps)
self.num_inference_steps = num_inference_steps
@@ -323,6 +330,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
@@ -452,6 +462,28 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,

src/diffusers/schedulers/scheduling_euler_discrete.py

@@ -197,6 +197,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -338,10 +340,6 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
if timesteps is not None and self.config.use_exponential_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
if self.config.use_exponential_sigmas and self.config.use_karras_sigmas:
raise ValueError(
"Cannot set both `config.use_exponential_sigmas = True` and config.use_karras_sigmas = True`"
)
if (
timesteps is not None
and self.config.timestep_type == "continuous"

src/diffusers/schedulers/scheduling_heun_discrete.py

@@ -97,6 +97,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -117,11 +119,14 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
clip_sample: Optional[bool] = False,
clip_sample_range: float = 1.0,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -251,6 +256,8 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
if timesteps is not None and self.config.use_karras_sigmas:
raise ValueError("Cannot use `timesteps` with `config.use_karras_sigmas = True`")
if timesteps is not None and self.config.use_exponential_sigmas:
raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
num_inference_steps = num_inference_steps or len(timesteps)
self.num_inference_steps = num_inference_steps
@@ -286,6 +293,9 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
@@ -354,6 +364,28 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
@property
def state_in_first_order(self):
return self.dt is None
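A side-by-side sketch on this scheduler's default config, comparing the Karras and exponential schedules; the printed values depend on the default betas:

```python
from diffusers import HeunDiscreteScheduler

for kwargs in ({"use_karras_sigmas": True}, {"use_exponential_sigmas": True}):
    sched = HeunDiscreteScheduler(**kwargs)
    sched.set_timesteps(num_inference_steps=6)
    print(kwargs, sched.sigmas)
```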

src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py

@@ -91,6 +91,8 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -114,10 +116,13 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -250,6 +255,9 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
self.log_sigmas = torch.from_numpy(log_sigmas).to(device)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -346,6 +354,28 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
@property
def state_in_first_order(self):
return self.sample is None

src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py

@@ -90,6 +90,8 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -113,10 +115,13 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -249,6 +254,9 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
self.log_sigmas = torch.from_numpy(log_sigmas).to(device=device)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -359,6 +367,28 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
def step(
self,
model_output: Union[torch.Tensor, np.ndarray],

src/diffusers/schedulers/scheduling_lms_discrete.py

@@ -111,6 +111,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
@@ -134,10 +136,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -289,6 +294,9 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
@@ -362,6 +370,28 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
def step(
self,
model_output: torch.Tensor,

src/diffusers/schedulers/scheduling_sasolver.py

@@ -122,6 +122,8 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
@@ -156,11 +158,14 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
algorithm_type: str = "data_prediction",
lower_order_final: bool = True,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -284,6 +289,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -395,6 +403,28 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,

src/diffusers/schedulers/scheduling_unipc_multistep.py

@@ -159,6 +159,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`,
the sigmas are determined according to a sequence of noise levels {σi}.
use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -195,11 +197,14 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
disable_corrector: List[int] = [],
solver_p: SchedulerMixin = None,
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
rescale_betas_zero_snr: bool = False,
):
if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
@@ -329,6 +334,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
else:
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
@@ -450,6 +458,28 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""Constructs an exponential noise schedule."""
# Hack to make sure that other schedulers which copy this function don't break
# TODO: Add this logic to the other schedulers
if hasattr(self.config, "sigma_min"):
sigma_min = self.config.sigma_min
else:
sigma_min = None
if hasattr(self.config, "sigma_max"):
sigma_max = self.config.sigma_max
else:
sigma_max = None
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
return sigmas
def convert_model_output(
self,
model_output: torch.Tensor,