Create "loglinear_sigmas" schedule.
Currently implemented in EulerDiscreteScheduler. An alternative would have been to initialize the scheduler with an array of `trained_betas`. However, that is currently not possible because of #1367.
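For illustration (not part of the commit), a minimal sketch of how the new schedule would be selected once this change lands; the `beta_start`/`beta_end` values below are the usual Stable Diffusion defaults, used here only as an example:

```
from diffusers import EulerDiscreteScheduler

# "loglinear_sigmas" becomes a valid beta_schedule value with this commit.
scheduler = EulerDiscreteScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="loglinear_sigmas",
)
print(scheduler.betas.shape)  # torch.Size([1000])
```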
src/diffusers/schedulers/scheduling_euler_discrete.py

@@ -20,7 +20,7 @@ import torch
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, betas_from_loglinear_sigmas


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -91,6 +91,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             self.betas = (
                 torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
             )
+        elif beta_schedule == "loglinear_sigmas":
+            # This schedule is specific to the k-diffusion latent upscaler.
+            # We use a helper function because the computation is somewhat involved.
+            # Alternative: create from a list of `trained_betas` (but see https://github.com/huggingface/diffusers/issues/1367)
+            self.betas = betas_from_loglinear_sigmas(beta_start, beta_end, num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
||||
@@ -16,6 +16,7 @@ import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..utils import BaseOutput
|
||||
@@ -152,3 +153,47 @@ class SchedulerMixin:
             getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
         ]
         return compatible_classes
+
+
+def betas_from_loglinear_sigmas(beta_start, beta_end, num_timesteps):
+    """
+    Computes the beta values suitable to create a loglinear schedule of sigmas,
+    as used in the k-diffusion latent upscaler.
+
+    Concretely, these are the betas that create a sigma schedule like the following:
+    ```
+    torch.linspace(np.log(sigma_max), np.log(sigma_min), num_timesteps).exp()
+    ```
+
+    Args:
+        beta_start (`float`): The starting beta value of the underlying scaled-linear schedule.
+        beta_end (`float`): The final beta value of the underlying scaled-linear schedule.
+        num_timesteps (`int`): The number of training timesteps.
+
+    Returns:
+        `torch.FloatTensor`: The beta values.
+    """
+    # First, compute sigma_min and sigma_max considering a "scaled_linear" schedule
+    # as used in Stable Diffusion. We just need sigma_min and sigma_max.
+    betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
+    alphas = 1 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+    sigma_min, sigma_max = sigmas[0], sigmas[-1]
+
+    # Then, compute the actual loglinear sigmas from sigma_min and sigma_max,
+    # reversed so they run from sigma_min up to sigma_max.
+    sigmas = torch.linspace(np.log(sigma_max), np.log(sigma_min), num_timesteps)
+    sigmas = sigmas.exp()
+    sigmas = np.array(sigmas)[::-1]
+    alphas_cumprod = 1.0 / (1 + sigmas**2)
+
+    # Recover the individual alpha values by undoing the cumulative product.
+    alphas = []
+    prev_prod = 1.0
+    for a in alphas_cumprod:
+        current_alpha = a / prev_prod
+        alphas.append(current_alpha)
+        prev_prod = a
+
+    # Finally, get the betas from the alphas.
+    betas = 1 - np.array(alphas)
+    return torch.Tensor(betas)
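As a quick sanity check (not part of the commit), one could verify that the returned betas reproduce a log-linear sigma schedule; a sketch, assuming the helper is importable from the module patched above and reusing the example Stable Diffusion beta values:

```
import torch
from diffusers.schedulers.scheduling_utils import betas_from_loglinear_sigmas

betas = betas_from_loglinear_sigmas(0.00085, 0.012, 1000)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5

# log(sigma) should increase by (approximately) the same amount at every timestep.
log_sigmas = sigmas.log()
steps = log_sigmas[1:] - log_sigmas[:-1]
print(steps.min().item(), steps.max().item())  # both close to log(sigma_max / sigma_min) / 999
```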