
Create "loglinear_sigmas" schedule.

Currently implemented in EulerDiscreteScheduler. An alternative would
have been to initialize the scheduler with an array of `trained_betas`.
However, that is currently not possible because of #1367.
Pedro Cuenca
2022-11-22 21:41:03 +01:00
parent b35a75a7f7
commit 70eb8970cc
2 changed files with 51 additions and 1 deletion

src/diffusers/schedulers/scheduling_euler_discrete.py

@@ -20,7 +20,7 @@ import torch
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, betas_from_loglinear_sigmas
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -91,6 +91,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             self.betas = (
                 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
             )
+        elif beta_schedule == "loglinear_sigmas":
+            # This schedule is specific to the k-diffusion latent upscaler
+            # We use a helper function because the computation is a bit involved
+            # Alternative: create from a list of `trained_betas` (but see https://github.com/huggingface/diffusers/issues/1367)
+            self.betas = betas_from_loglinear_sigmas(beta_start, beta_end, num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
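
For orientation (not part of the diff), a minimal sketch of how the new option would be selected once this change lands; the `beta_start`/`beta_end` values below are illustrative, not taken from the commit:

```
from diffusers import EulerDiscreteScheduler

# Hypothetical example: select the new schedule via the existing
# `beta_schedule` constructor argument; beta values are placeholders.
scheduler = EulerDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="loglinear_sigmas",
)
```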

src/diffusers/schedulers/scheduling_utils.py

@@ -16,6 +16,7 @@ import os
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Union
+import numpy as np
 import torch
 from ..utils import BaseOutput
@@ -152,3 +153,47 @@ class SchedulerMixin:
             getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
         ]
         return compatible_classes
+
+
+def betas_from_loglinear_sigmas(beta_start, beta_end, num_timesteps):
+    """
+    Computes the beta values that produce a loglinear schedule of sigmas,
+    as used in the k-diffusion latent upscaler.
+
+    Concretely, these are the betas that create a sigma schedule like the following:
+    ```
+    torch.linspace(np.log(sigma_max), np.log(sigma_min), num_timesteps).exp()
+    ```
+
+    Args:
+        beta_start (`float`): The starting beta value of the reference scaled-linear schedule.
+        beta_end (`float`): The final beta value of the reference scaled-linear schedule.
+        num_timesteps (`int`): The number of training timesteps.
+
+    Returns:
+        `torch.FloatTensor`: The beta values.
+    """
+    # First, compute sigma_max and sigma_min considering a "scaled_linear" schedule
+    # as used in Stable Diffusion. We just need sigma_min and sigma_max.
+    betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
+    alphas = 1 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+    sigma_min, sigma_max = sigmas[0], sigmas[-1]
+
+    # Then, compute the actual loglinear sigmas from sigma_min, sigma_max
+    sigmas = torch.linspace(np.log(sigma_max), np.log(sigma_min), num_timesteps)
+    sigmas = sigmas.exp()
+    sigmas = np.array(sigmas)[::-1]
+    alpha_cumprod = 1.0 / (1 + sigmas**2)
+
+    # Compute the alpha values by reversing `alpha_cumprod`
+    alphas = []
+    prev_prod = 1.0
+    for a in alpha_cumprod:
+        current_alpha = a / prev_prod
+        alphas.append(current_alpha)
+        prev_prod = a
+
+    # Get the betas from the alphas
+    betas = 1 - np.array(alphas)
+    return torch.Tensor(betas)
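
As a sanity check (again, not part of the diff), the returned betas can be round-tripped back into sigmas to confirm that they are near-evenly spaced in log space; the argument values here are illustrative assumptions:

```
import numpy as np
import torch

# Hypothetical check: rebuild the sigma schedule from the betas and
# verify that log(sigma) increases in (approximately) constant steps.
betas = betas_from_loglinear_sigmas(0.00085, 0.012, 1000)  # illustrative values
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
steps = torch.diff(torch.log(sigmas))
# min and max increments should be nearly identical (small float32 drift aside)
print(steps.min().item(), steps.max().item())
```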