Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[2064]: Add stochastic sampler (sample_dpmpp_sde) (#3020)
* [2064]: Add stochastic sampler
* Review comments
* [Review comment]: Add is_torchsde_available()
* [Review comment]: Test and docs
---------
Co-authored-by: njindal <njindal@adobe.com>
@@ -266,6 +266,8 @@
       title: VP-SDE
     - local: api/schedulers/vq_diffusion
       title: VQDiffusionScheduler
+    - local: api/schedulers/dpm_sde
+      title: DPMSolverSDEScheduler
     title: Schedulers
 - sections:
   - local: api/experimental/rl
docs/source/en/api/schedulers/dpm_sde.mdx (new file, 23 lines)
@@ -0,0 +1,23 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DPM Stochastic Scheduler inspired by the Karras et al. paper

## Overview

Inspired by the Stochastic Sampler from [Karras et al.](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library.

All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/).

## DPMSolverSDEScheduler
[[autodoc]] DPMSolverSDEScheduler
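A minimal usage sketch (not part of the commit): like the other Karras-compatible schedulers, DPMSolverSDEScheduler can be swapped into an existing pipeline from the old scheduler's config. The checkpoint id and prompt below are illustrative assumptions only.

from diffusers import DiffusionPipeline, DPMSolverSDEScheduler

# Load any Stable-Diffusion-style pipeline; the checkpoint id is an assumption for illustration.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Re-create the scheduler from the existing config so the noise schedule stays consistent.
pipe.scheduler = DPMSolverSDEScheduler.from_config(pipe.scheduler.config)

image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]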
@@ -12,6 +12,7 @@ from .utils import (
     is_onnx_available,
     is_scipy_available,
     is_torch_available,
+    is_torchsde_available,
     is_transformers_available,
     is_transformers_version,
     is_unidecode_available,
@@ -102,6 +103,13 @@ except OptionalDependencyNotAvailable:
 else:
     from .schedulers import LMSDiscreteScheduler

+try:
+    if not (is_torch_available() and is_torchsde_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from .utils.dummy_torch_and_torchsde_objects import *  # noqa F403
+else:
+    from .schedulers import DPMSolverSDEScheduler
+
 try:
     if not (is_torch_available() and is_transformers_available()):
@@ -13,7 +13,13 @@
 # limitations under the License.

-from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available
+from ..utils import (
+    OptionalDependencyNotAvailable,
+    is_flax_available,
+    is_scipy_available,
+    is_torch_available,
+    is_torchsde_available,
+)

 try:
@@ -72,3 +78,11 @@ except OptionalDependencyNotAvailable:
     from ..utils.dummy_torch_and_scipy_objects import *  # noqa F403
 else:
     from .scheduling_lms_discrete import LMSDiscreteScheduler
+
+try:
+    if not (is_torch_available() and is_torchsde_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ..utils.dummy_torch_and_torchsde_objects import *  # noqa F403
+else:
+    from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler
src/diffusers/schedulers/scheduling_dpmsolver_sde.py (new file, 447 lines)
@@ -0,0 +1,447 @@
# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torchsde

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy."""

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        w0 = kwargs.get("w0", torch.zeros_like(x))
        if seed is None:
            seed = torch.randint(0, 2**63 - 1, []).item()
        self.batched = True
        try:
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            seed = [seed]
            self.batched = False
        self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]


class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is supplied instead of a single integer, then
            the noise sampler will use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's internal timestep.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t0, t1, seed)

    def __call__(self, sigma, sigma_next):
        t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
        return self.tree(t0, t1) / (t1 - t0).abs().sqrt()


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to prevent singularities.

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
    """
    Implements the Stochastic Sampler (Algorithm 2) from Karras et al. (2022). Based on the original k-diffusion
    implementation by Katherine Crowson:
    https://github.com/crowsonkb/k-diffusion/blob/41b4cb6df0506694a7776af31349acf082bf6091/k_diffusion/sampling.py#L543

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`]
    and [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear` or `scaled_linear`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 of
            https://imagen.research.google/video/paper.pdf)
        use_karras_sigmas (`bool`, *optional*, defaults to `False`):
            This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
            noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
            of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
        noise_sampler_seed (`int`, *optional*, defaults to `None`):
            The random seed to use for the noise sampler. If `None`, a random seed will be generated.
    """

    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,  # sensible defaults
        beta_end: float = 0.012,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        prediction_type: str = "epsilon",
        use_karras_sigmas: Optional[bool] = False,
        noise_sampler_seed: Optional[int] = None,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = (
                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
            )
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # set all values
        self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
        self.use_karras_sigmas = use_karras_sigmas
        self.noise_sampler = None
        self.noise_sampler_seed = noise_sampler_seed

    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        if self.state_in_first_order:
            pos = -1
        else:
            pos = 0
        return indices[pos].item()

    def scale_model_input(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        step_index = self.index_for_timestep(timestep)

        sigma = self.sigmas[step_index]
        sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma
        sample = sample / ((sigma_input**2 + 1) ** 0.5)
        return sample

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        num_train_timesteps: Optional[int] = None,
    ):
        """
        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, optional):
                the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps

        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        log_sigmas = np.log(sigmas)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)

        if self.use_karras_sigmas:
            sigmas = self._convert_to_karras(in_sigmas=sigmas)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

        second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)

        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        sigmas = torch.from_numpy(sigmas).to(device=device)
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = self.sigmas.max()

        timesteps = torch.from_numpy(timesteps)
        second_order_timesteps = torch.from_numpy(second_order_timesteps)
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        timesteps[1::2] = second_order_timesteps

        if str(device).startswith("mps"):
            # mps does not support float64
            self.timesteps = timesteps.to(device, dtype=torch.float32)
        else:
            self.timesteps = timesteps.to(device=device)

        # empty first order variables
        self.sample = None
        self.mid_point_sigma = None

    def _second_order_timesteps(self, sigmas, log_sigmas):
        def sigma_fn(_t):
            return np.exp(-_t)

        def t_fn(_sigma):
            return -np.log(_sigma)

        midpoint_ratio = 0.5
        t = t_fn(sigmas)
        delta_time = np.diff(t)
        t_proposed = t[:-1] + delta_time * midpoint_ratio
        sig_proposed = sigma_fn(t_proposed)
        timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed])
        return timesteps

    # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        # get log sigma
        log_sigma = np.log(sigma)

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min: float = in_sigmas[-1].item()
        sigma_max: float = in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, self.num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    @property
    def state_in_first_order(self):
        return self.sample is None

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: Union[float, torch.FloatTensor],
        sample: Union[torch.FloatTensor, np.ndarray],
        return_dict: bool = True,
        s_noise: float = 1.0,
    ) -> Union[SchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (Union[torch.FloatTensor, np.ndarray]): Direct output from learned diffusion model.
            timestep (Union[float, torch.FloatTensor]): Current discrete timestep in the diffusion chain.
            sample (Union[torch.FloatTensor, np.ndarray]): Current instance of sample being created by diffusion process.
            return_dict (bool, optional): Option for returning tuple rather than SchedulerOutput class. Defaults to True.
            s_noise (float, optional): Scaling factor for the noise added to the sample. Defaults to 1.0.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        step_index = self.index_for_timestep(timestep)

        # Create a noise sampler if it hasn't been created yet
        if self.noise_sampler is None:
            min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
            self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)

        # Define functions to compute sigma and t from each other
        def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor:
            return _t.neg().exp()

        def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor:
            return _sigma.log().neg()

        if self.state_in_first_order:
            sigma = self.sigmas[step_index]
            sigma_next = self.sigmas[step_index + 1]
        else:
            # 2nd order
            sigma = self.sigmas[step_index - 1]
            sigma_next = self.sigmas[step_index]

        # Set the midpoint and step size for the current step
        midpoint_ratio = 0.5
        t, t_next = t_fn(sigma), t_fn(sigma_next)
        delta_time = t_next - t
        t_proposed = t + delta_time * midpoint_ratio

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
            pred_original_sample = sample - sigma_input * model_output
        elif self.config.prediction_type == "v_prediction":
            sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
            pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
                sample / (sigma_input**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        if sigma_next == 0:
            derivative = (sample - pred_original_sample) / sigma
            dt = sigma_next - sigma
            prev_sample = sample + derivative * dt
        else:
            if self.state_in_first_order:
                t_next = t_proposed
            else:
                sample = self.sample

            sigma_from = sigma_fn(t)
            sigma_to = sigma_fn(t_next)
            sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
            sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
            ancestral_t = t_fn(sigma_down)
            prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - (
                t - ancestral_t
            ).expm1() * pred_original_sample
            prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up

        if self.state_in_first_order:
            # store for 2nd order step
            self.sample = sample
            self.mid_point_sigma = sigma_fn(t_next)
        else:
            # free for "first order mode"
            self.sample = None
            self.mid_point_sigma = None

        if not return_dict:
            return (prev_sample,)

        return SchedulerOutput(prev_sample=prev_sample)

    # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps
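A small worked sketch (not part of the commit) of the ancestral split used in the stochastic branch of step() above: the target noise level sigma_to is divided into a deterministic part sigma_down and a stochastic part sigma_up such that sigma_down**2 + sigma_up**2 == sigma_to**2. The numbers are made up for illustration.

import math

sigma_from, sigma_to = 2.0, 1.0  # hypothetical consecutive noise levels

sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

# sigma_up ≈ 0.866 and sigma_down == 0.5 here; the two parts recombine to the target level
assert math.isclose(sigma_down**2 + sigma_up**2, sigma_to**2)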
@@ -70,8 +70,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             `linear` or `scaled_linear`.
         trained_betas (`np.ndarray`, optional):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
-            options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
-            `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
         prediction_type (`str`, default `epsilon`, optional):
             prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
             process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
@@ -43,6 +43,7 @@ class KarrasDiffusionSchedulers(Enum):
     KDPM2AncestralDiscreteScheduler = 11
     DEISMultistepScheduler = 12
     UniPCMultistepScheduler = 13
+    DPMSolverSDEScheduler = 14


 @dataclass
@@ -70,6 +70,7 @@ from .import_utils import (
     is_tf_available,
     is_torch_available,
     is_torch_version,
+    is_torchsde_available,
     is_transformers_available,
     is_transformers_version,
     is_unidecode_available,
src/diffusers/utils/dummy_torch_and_torchsde_objects.py (new file, 17 lines)
@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends


class DPMSolverSDEScheduler(metaclass=DummyObject):
    _backends = ["torch", "torchsde"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "torchsde"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "torchsde"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "torchsde"])
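When torchsde is missing, this autogenerated dummy class is what stands in for the real scheduler, and any attempt to use it fails at call time rather than at import time. A sketch of the expected behavior, assuming an environment without torchsde installed:

from diffusers.utils.dummy_torch_and_torchsde_objects import DPMSolverSDEScheduler

try:
    DPMSolverSDEScheduler()  # requires_backends raises because the torchsde backend is absent
except ImportError as err:
    print(err)  # the TORCHSDE_IMPORT_ERROR message, suggesting `pip install torchsde`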
@@ -287,6 +287,13 @@ try:
 except importlib_metadata.PackageNotFoundError:
     _bs4_available = False

+_torchsde_available = importlib.util.find_spec("torchsde") is not None
+try:
+    _torchsde_version = importlib_metadata.version("torchsde")
+    logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
+except importlib_metadata.PackageNotFoundError:
+    _torchsde_available = False
+

 def is_torch_available():
     return _torch_available
@@ -372,6 +379,10 @@ def is_bs4_available():
     return _bs4_available


+def is_torchsde_available():
+    return _torchsde_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -475,6 +486,11 @@ installation section: https://github.com/rspeer/python-ftfy/tree/master#installi
 that match your environment. Please note that you may need to restart your runtime after installation.
 """

+# docstyle-ignore
+TORCHSDE_IMPORT_ERROR = """
+{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde`
+"""
+

 BACKENDS_MAPPING = OrderedDict(
     [
@@ -495,6 +511,7 @@ BACKENDS_MAPPING = OrderedDict(
         ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
         ("compel", (_compel_available, COMPEL_IMPORT_ERROR)),
         ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
+        ("torchsde", (_torchsde_available, TORCHSDE_IMPORT_ERROR)),
     ]
 )
@@ -26,6 +26,7 @@ from .import_utils import (
     is_opencv_available,
     is_torch_available,
     is_torch_version,
+    is_torchsde_available,
 )
 from .logging import get_logger

@@ -216,6 +217,13 @@ def require_note_seq(test_case):
     return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)


+def require_torchsde(test_case):
+    """
+    Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
+    """
+    return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
+
+
 def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
     if isinstance(arry, str):
         # local_path = "/home/patrick_huggingface_co/"
tests/schedulers/test_scheduler_dpm_sde.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import torch

from diffusers import DPMSolverSDEScheduler
from diffusers.utils import torch_device
from diffusers.utils.testing_utils import require_torchsde

from .test_schedulers import SchedulerCommonTest


@require_torchsde
class DPMSolverSDESchedulerTest(SchedulerCommonTest):
    scheduler_classes = (DPMSolverSDEScheduler,)
    num_inference_steps = 10

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1100,
            "beta_start": 0.0001,
            "beta_end": 0.02,
            "beta_schedule": "linear",
            "noise_sampler_seed": 0,
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [10, 50, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_betas(self):
        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
        for schedule in ["linear", "scaled_linear"]:
            self.check_over_configs(beta_schedule=schedule)

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "v_prediction"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(self.num_inference_steps)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
        sample = sample.to(torch_device)

        for i, t in enumerate(scheduler.timesteps):
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        if torch_device in ["mps"]:
            assert abs(result_sum.item() - 167.47821044921875) < 1e-2
            assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
        else:
            assert abs(result_sum.item() - 162.52383422851562) < 1e-2
            assert abs(result_mean.item() - 0.211619570851326) < 1e-3

    def test_full_loop_with_v_prediction(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(self.num_inference_steps)

        model = self.dummy_model()
        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
        sample = sample.to(torch_device)

        for i, t in enumerate(scheduler.timesteps):
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        if torch_device in ["mps"]:
            assert abs(result_sum.item() - 124.77149200439453) < 1e-2
            assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
        else:
            assert abs(result_sum.item() - 119.8487548828125) < 1e-2
            assert abs(result_mean.item() - 0.1560530662536621) < 1e-3

    def test_full_loop_device(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

        model = self.dummy_model()
        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma

        for t in scheduler.timesteps:
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        if torch_device in ["mps"]:
            assert abs(result_sum.item() - 167.46957397460938) < 1e-2
            assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
        else:
            assert abs(result_sum.item() - 162.52383422851562) < 1e-2
            assert abs(result_mean.item() - 0.211619570851326) < 1e-3

    def test_full_loop_device_karras_sigmas(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True)

        scheduler.set_timesteps(self.num_inference_steps, device=torch_device)

        model = self.dummy_model()
        sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
        sample = sample.to(torch_device)

        for t in scheduler.timesteps:
            sample = scheduler.scale_model_input(sample, t)

            model_output = model(sample, t)

            output = scheduler.step(model_output, t, sample)
            sample = output.prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        if torch_device in ["mps"]:
            assert abs(result_sum.item() - 176.66974135742188) < 1e-2
            assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
        else:
            assert abs(result_sum.item() - 170.3135223388672) < 1e-2
            assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
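Assuming torchsde is installed, the new module can be run locally with, for example, `python -m pytest tests/schedulers/test_scheduler_dpm_sde.py`; without torchsde, the @require_torchsde decorator skips the whole test class.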