From 70eb8970ccb744c0a41e175bc4f4d30a74db267e Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Tue, 22 Nov 2022 21:41:03 +0100
Subject: [PATCH] Create "loglinear_sigmas" schedule.

Currently implemented in EulerDiscreteScheduler.

An alternative would have been to initialize the scheduler with an array of
`trained_betas`. However, that is currently not possible because of #1367.
---
 .../schedulers/scheduling_euler_discrete.py  |  7 ++-
 src/diffusers/schedulers/scheduling_utils.py | 45 +++++++++++++++++++
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 3fb8c1f0b3..ea6fcba2ef 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -20,7 +20,7 @@ import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, betas_from_loglinear_sigmas
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -91,6 +91,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             self.betas = (
                 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
             )
+        elif beta_schedule == "loglinear_sigmas":
+            # This schedule is specific to the k-diffusion latent upscaler
+            # We use a helper function because the computation is a bit involved
+            # Alternative: create from a list of `trained_betas` (but see https://github.com/huggingface/diffusers/issues/1367)
+            self.betas = betas_from_loglinear_sigmas(beta_start, beta_end, num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index 90ab674e38..4c5dc53020 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -16,6 +16,7 @@ import os
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Union
 
+import numpy as np
 import torch
 
 from ..utils import BaseOutput
@@ -152,3 +153,47 @@ class SchedulerMixin:
             getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
         ]
         return compatible_classes
+
+def betas_from_loglinear_sigmas(beta_start, beta_end, num_timesteps):
+    """
+    Computes the beta values suitable to create a loglinear schedule of sigmas,
+    as used in the k-diffusion latent upscaler.
+
+    Concretely, these are the betas that create a sigma schedule like the following:
+    ```
+    torch.linspace(np.log(sigma_max), np.log(sigma_min), num_timesteps).exp()
+    ```
+
+    Args:
+        beta_start (`float`): The start beta value.
+        beta_end (`float`): The end beta value.
+        num_timesteps (`int`): The number of training timesteps.
+
+    Returns:
+        `torch.FloatTensor`: The beta values.
+    """
+    # First, compute sigma_max and sigma_min considering a "scaled_linear" schedule
+    # as used in Stable Diffusion. We just need sigma_min and sigma_max.
+    betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
+    alphas = 1 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+    sigma_min, sigma_max = sigmas[0], sigmas[-1]
+
+    # Then, compute the actual loglinear sigmas from sigma_min, sigma_max
+    sigmas = torch.linspace(np.log(sigma_max), np.log(sigma_min), num_timesteps)
+    sigmas = sigmas.exp()
+    sigmas = np.array(sigmas)[::-1]
+    alpha_cumprod = 1./(1+sigmas**2)
+
+    # Compute the alpha values reversing alpha_cumprod
+    alphas = []
+    prev_prod = 1.
+    for a in alpha_cumprod:
+        current_alpha = a / prev_prod
+        alphas.append(current_alpha)
+        prev_prod = a
+
+    # Get the betas from the alphas
+    betas = 1 - np.array(alphas)
+    return torch.Tensor(betas)
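
As a quick sanity check of what the new schedule produces, here is a self-contained sketch (plain torch/numpy, not part of the patch) that mirrors the `betas_from_loglinear_sigmas` helper above and verifies that the betas it returns round-trip back to log-linearly spaced sigmas. The local name `loglinear_betas` and the `beta_start=0.00085`, `beta_end=0.012` values are illustrative only (typical Stable Diffusion "scaled_linear" settings); the k-diffusion latent upscaler may use different values.

```python
import numpy as np
import torch


def loglinear_betas(beta_start, beta_end, num_timesteps):
    # Standalone mirror of the patch's helper, in float64 for a cleaner round trip.
    # Derive sigma_min and sigma_max from a "scaled_linear" beta schedule.
    betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps, dtype=torch.float64) ** 2
    alphas_cumprod = torch.cumprod(1 - betas, dim=0)
    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
    sigma_min, sigma_max = sigmas[0].item(), sigmas[-1].item()

    # Log-linear sigmas, ascending from sigma_min to sigma_max
    # (equivalent to the descending linspace + reversal used in the patch).
    sigmas = torch.linspace(np.log(sigma_min), np.log(sigma_max), num_timesteps, dtype=torch.float64).exp()
    alphas_cumprod = 1.0 / (1.0 + sigmas**2)

    # Undo the cumulative product to recover per-step alphas, then betas.
    alphas = alphas_cumprod / torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]])
    return 1 - alphas


betas = loglinear_betas(0.00085, 0.012, 1000)

# Round trip: the sigmas implied by these betas should be log-linearly spaced.
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
steps = torch.diff(torch.log(sigmas))
print(f"max deviation from uniform log-spacing: {(steps.max() - steps.min()).item():.2e}")
```

With the patch applied, essentially the same betas should come out of `EulerDiscreteScheduler(beta_start=..., beta_end=..., num_train_timesteps=..., beta_schedule="loglinear_sigmas")`, computed in float32 by the new helper rather than by this float64 sketch.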