1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Laplace Scheduler for DDPM (#11320)

* Add Laplace scheduler that samples more around mid-range noise levels (around log SNR=0), increasing performance (lower FID) with faster convergence speed, and robust to resolution and objective. Reference:  https://arxiv.org/pdf/2407.03297.

* Fix copies.

* Apply style fixes

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
gapatron
2026-01-09 14:16:02 -05:00
committed by GitHub
parent 632765a5ee
commit 644169433f
26 changed files with 186 additions and 0 deletions

View File

@@ -40,6 +40,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -75,6 +75,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -74,6 +74,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
@@ -207,6 +214,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "laplace":
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
elif beta_schedule == "sigmoid":
# GeoDiff sigmoid schedule
betas = torch.linspace(-6, 6, num_train_timesteps)

View File

@@ -76,6 +76,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):
@@ -217,6 +224,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
elif beta_schedule == "laplace":
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
elif beta_schedule == "sigmoid":
# GeoDiff sigmoid schedule
betas = torch.linspace(-6, 6, num_train_timesteps)

View File

@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -78,6 +78,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -143,6 +143,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -62,6 +62,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -80,6 +80,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -78,6 +78,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -79,6 +79,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -75,6 +75,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -54,6 +54,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -73,6 +73,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -61,6 +61,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -78,6 +78,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -74,6 +74,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):

View File

@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
def alpha_bar_fn(t):
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
elif alpha_transform_type == "laplace":
def alpha_bar_fn(t):
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
snr = math.exp(lmb)
return math.sqrt(snr / (1 + snr))
elif alpha_transform_type == "exp":
def alpha_bar_fn(t):