mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add use_Karras_sigmas to DPMSolverSinglestepScheduler (#3476)
* add use_karras_sigmas * add karras test * add doc
This commit is contained in:
@@ -117,6 +117,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
lower_order_final (`bool`, default `True`):
|
||||
whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable
|
||||
this to use up all the function evaluations.
|
||||
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
|
||||
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
|
||||
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
|
||||
lambda_min_clipped (`float`, default `-inf`):
|
||||
the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for
|
||||
cosine (squaredcos_cap_v2) noise schedule.
|
||||
@@ -150,6 +154,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
):
|
||||
@@ -197,6 +202,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.sample = None
|
||||
self.order_list = self.get_order_list(num_train_timesteps)
|
||||
self.use_karras_sigmas = use_karras_sigmas
|
||||
|
||||
def get_order_list(self, num_inference_steps: int) -> List[int]:
|
||||
"""
|
||||
@@ -252,6 +258,14 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
.copy()
|
||||
.astype(np.int64)
|
||||
)
|
||||
|
||||
if self.use_karras_sigmas:
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
||||
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
self.model_outputs = [None] * self.config.solver_order
|
||||
self.sample = None
|
||||
@@ -299,6 +313,44 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
# get log sigma
|
||||
log_sigma = np.log(sigma)
|
||||
|
||||
# get distribution
|
||||
dists = log_sigma - log_sigmas[:, np.newaxis]
|
||||
|
||||
# get sigmas range
|
||||
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
|
||||
high_idx = low_idx + 1
|
||||
|
||||
low = log_sigmas[low_idx]
|
||||
high = log_sigmas[high_idx]
|
||||
|
||||
# interpolate sigmas
|
||||
w = (low - log_sigma) / (low - high)
|
||||
w = np.clip(w, 0, 1)
|
||||
|
||||
# transform interpolation to time range
|
||||
t = (1 - w) * low_idx + w * high_idx
|
||||
t = t.reshape(sigma.shape)
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
sigma_min: float = in_sigmas[-1].item()
|
||||
sigma_max: float = in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
@@ -215,12 +215,24 @@ class DPMSolverSinglestepSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
assert abs(result_mean.item() - 0.2791) < 1e-3
|
||||
|
||||
def test_full_loop_with_karras(self):
|
||||
sample = self.full_loop(use_karras_sigmas=True)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.2248) < 1e-3
|
||||
|
||||
def test_full_loop_with_v_prediction(self):
|
||||
sample = self.full_loop(prediction_type="v_prediction")
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.1453) < 1e-3
|
||||
|
||||
def test_full_loop_with_karras_and_v_prediction(self):
|
||||
sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.0649) < 1e-3
|
||||
|
||||
def test_fp16_support(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0)
|
||||
|
||||
Reference in New Issue
Block a user