mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Laplace Scheduler for DDPM (#11320)
* Add Laplace scheduler that samples more around mid-range noise levels (around log SNR=0), increasing performance (lower FID) with faster convergence speed, and robust to resolution and objective. Reference: https://arxiv.org/pdf/2407.03297. * Fix copies. * Apply style fixes --------- Co-authored-by: YiYi Xu <yixu310@gmail.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -40,6 +40,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -75,6 +75,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -74,6 +74,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
@@ -207,6 +214,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
elif beta_schedule == "laplace":
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
|
||||
elif beta_schedule == "sigmoid":
|
||||
# GeoDiff sigmoid schedule
|
||||
betas = torch.linspace(-6, 6, num_train_timesteps)
|
||||
|
||||
@@ -76,6 +76,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
@@ -217,6 +224,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
elif beta_schedule == "laplace":
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps, alpha_transform_type="laplace")
|
||||
elif beta_schedule == "sigmoid":
|
||||
# GeoDiff sigmoid schedule
|
||||
betas = torch.linspace(-6, 6, num_train_timesteps)
|
||||
|
||||
@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -78,6 +78,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -143,6 +143,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -62,6 +62,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -80,6 +80,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -78,6 +78,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -77,6 +77,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -79,6 +79,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -75,6 +75,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -54,6 +54,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -73,6 +73,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -61,6 +61,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -78,6 +78,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -74,6 +74,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
@@ -60,6 +60,13 @@ def betas_for_alpha_bar(
|
||||
def alpha_bar_fn(t):
|
||||
return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
|
||||
|
||||
elif alpha_transform_type == "laplace":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
lmb = -0.5 * math.copysign(1, 0.5 - t) * math.log(1 - 2 * math.fabs(0.5 - t) + 1e-6)
|
||||
snr = math.exp(lmb)
|
||||
return math.sqrt(snr / (1 + snr))
|
||||
|
||||
elif alpha_transform_type == "exp":
|
||||
|
||||
def alpha_bar_fn(t):
|
||||
|
||||
Reference in New Issue
Block a user