1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[2064]: Add stochastic sampler (sample_dpmpp_sde) (#3020)

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* [2064]: Add stochastic sampler

* Review comments

* [Review comment]: Add is_torchsde_available()

* [Review comment]: Test and docs

* [Review comment]

* [Review comment]

* [Review comment]

* [Review comment]

* [Review comment]

---------

Co-authored-by: njindal <njindal@adobe.com>
This commit is contained in:
Nipun Jindal
2023-04-27 11:18:38 +05:30
committed by GitHub
parent e0a2bd15f9
commit fd512d7461
12 changed files with 695 additions and 3 deletions

View File

@@ -266,6 +266,8 @@
title: VP-SDE
- local: api/schedulers/vq_diffusion
title: VQDiffusionScheduler
- local: api/schedulers/dpm_sde
title: DPMSolverSDEScheduler
title: Schedulers
- sections:
- local: api/experimental/rl

View File

@@ -0,0 +1,23 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# DPM Stochastic Scheduler inspired by Karras et. al paper
## Overview
Inspired by Stochastic Sampler from [Karras et. al](https://arxiv.org/abs/2206.00364).
Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library:
All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/)
## DPMSolverSDEScheduler
[[autodoc]] DPMSolverSDEScheduler

View File

@@ -12,6 +12,7 @@ from .utils import (
is_onnx_available,
is_scipy_available,
is_torch_available,
is_torchsde_available,
is_transformers_available,
is_transformers_version,
is_unidecode_available,
@@ -102,6 +103,13 @@ except OptionalDependencyNotAvailable:
else:
from .schedulers import LMSDiscreteScheduler
try:
if not (is_torch_available() and is_torchsde_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_torchsde_objects import * # noqa F403
else:
from .schedulers import DPMSolverSDEScheduler
try:
if not (is_torch_available() and is_transformers_available()):

View File

@@ -13,7 +13,13 @@
# limitations under the License.
from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available
from ..utils import (
OptionalDependencyNotAvailable,
is_flax_available,
is_scipy_available,
is_torch_available,
is_torchsde_available,
)
try:
@@ -72,3 +78,11 @@ except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_scipy_objects import * # noqa F403
else:
from .scheduling_lms_discrete import LMSDiscreteScheduler
try:
if not (is_torch_available() and is_torchsde_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403
else:
from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler

View File

@@ -0,0 +1,447 @@
# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torchsde
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def __init__(self, x, t0, t1, seed=None, **kwargs):
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get("w0", torch.zeros_like(x))
if seed is None:
seed = torch.randint(0, 2**63 - 1, []).item()
self.batched = True
try:
assert len(seed) == x.shape[0]
w0 = w0[0]
except TypeError:
seed = [seed]
self.batched = False
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
@staticmethod
def sort(a, b):
return (a, b, 1) if a < b else (b, a, -1)
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
return w if self.batched else w[0]
class BrownianTreeNoiseSampler:
"""A noise sampler backed by a torchsde.BrownianTree.
Args:
x (Tensor): The tensor whose shape, device and dtype to use to generate
random samples.
sigma_min (float): The low end of the valid interval.
sigma_max (float): The high end of the valid interval.
seed (int or List[int]): The random seed. If a list of seeds is
supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each
with its own seed.
transform (callable): A function that maps sigma to the sampler's
internal timestep.
"""
def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
self.transform = transform
t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
self.tree = BatchedBrownianTree(x, t0, t1, seed)
def __call__(self, sigma, sigma_next):
t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
"""
Implements Stochastic Sampler (Algorithm 2) from Karras et al. (2022). Based on the original k-diffusion
implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/41b4cb6df0506694a7776af31349acf082bf6091/k_diffusion/sampling.py#L543
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
use_karras_sigmas (`bool`, *optional*, defaults to `False`):
This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the
noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence
of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf.
noise_sampler_seed (`int`, *optional*, defaults to `None`):
The random seed to use for the noise sampler. If `None`, a random seed will be generated.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 2
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.00085, # sensible defaults
beta_end: float = 0.012,
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
use_karras_sigmas: Optional[bool] = False,
noise_sampler_seed: Optional[int] = None,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
# set all values
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
self.use_karras_sigmas = use_karras_sigmas
self.noise_sampler = None
self.noise_sampler_seed = noise_sampler_seed
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
if self.state_in_first_order:
pos = -1
else:
pos = 0
return indices[pos].item()
def scale_model_input(
self,
sample: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
) -> torch.FloatTensor:
"""
Args:
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
step_index = self.index_for_timestep(timestep)
sigma = self.sigmas[step_index]
sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma
sample = sample / ((sigma_input**2 + 1) ** 0.5)
return sample
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
num_train_timesteps: Optional[int] = None,
):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, optional):
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
self.num_inference_steps = num_inference_steps
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
sigmas = torch.from_numpy(sigmas).to(device=device)
self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])
# standard deviation of the initial noise distribution
self.init_noise_sigma = self.sigmas.max()
timesteps = torch.from_numpy(timesteps)
second_order_timesteps = torch.from_numpy(second_order_timesteps)
timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
timesteps[1::2] = second_order_timesteps
if str(device).startswith("mps"):
# mps does not support float64
self.timesteps = timesteps.to(device, dtype=torch.float32)
else:
self.timesteps = timesteps.to(device=device)
# empty first order variables
self.sample = None
self.mid_point_sigma = None
def _second_order_timesteps(self, sigmas, log_sigmas):
def sigma_fn(_t):
return np.exp(-_t)
def t_fn(_sigma):
return -np.log(_sigma)
midpoint_ratio = 0.5
t = t_fn(sigmas)
delta_time = np.diff(t)
t_proposed = t[:-1] + delta_time * midpoint_ratio
sig_proposed = sigma_fn(t_proposed)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed])
return timesteps
# copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
# get log sigma
log_sigma = np.log(sigma)
# get distribution
dists = log_sigma - log_sigmas[:, np.newaxis]
# get sigmas range
low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
high_idx = low_idx + 1
low = log_sigmas[low_idx]
high = log_sigmas[high_idx]
# interpolate sigmas
w = (low - log_sigma) / (low - high)
w = np.clip(w, 0, 1)
# transform interpolation to time range
t = (1 - w) * low_idx + w * high_idx
t = t.reshape(sigma.shape)
return t
# copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
sigma_min: float = in_sigmas[-1].item()
sigma_max: float = in_sigmas[0].item()
rho = 7.0 # 7.0 is the value used in the paper
ramp = np.linspace(0, 1, self.num_inference_steps)
min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
@property
def state_in_first_order(self):
return self.sample is None
def step(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: Union[float, torch.FloatTensor],
sample: Union[torch.FloatTensor, np.ndarray],
return_dict: bool = True,
s_noise: float = 1.0,
) -> Union[SchedulerOutput, Tuple]:
"""
Args:
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
model_output (Union[torch.FloatTensor, np.ndarray]): Direct output from learned diffusion model.
timestep (Union[float, torch.FloatTensor]): Current discrete timestep in the diffusion chain.
sample (Union[torch.FloatTensor, np.ndarray]): Current instance of sample being created by diffusion process.
return_dict (bool, optional): Option for returning tuple rather than SchedulerOutput class. Defaults to True.
s_noise (float, optional): Scaling factor for the noise added to the sample. Defaults to 1.0.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
step_index = self.index_for_timestep(timestep)
# Create a noise sampler if it hasn't been created yet
if self.noise_sampler is None:
min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max()
self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed)
# Define functions to compute sigma and t from each other
def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor:
return _t.neg().exp()
def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor:
return _sigma.log().neg()
if self.state_in_first_order:
sigma = self.sigmas[step_index]
sigma_next = self.sigmas[step_index + 1]
else:
# 2nd order
sigma = self.sigmas[step_index - 1]
sigma_next = self.sigmas[step_index]
# Set the midpoint and step size for the current step
midpoint_ratio = 0.5
t, t_next = t_fn(sigma), t_fn(sigma_next)
delta_time = t_next - t
t_proposed = t + delta_time * midpoint_ratio
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
pred_original_sample = sample - sigma_input * model_output
elif self.config.prediction_type == "v_prediction":
sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed)
pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
sample / (sigma_input**2 + 1)
)
elif self.config.prediction_type == "sample":
raise NotImplementedError("prediction_type not implemented yet: sample")
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
if sigma_next == 0:
derivative = (sample - pred_original_sample) / sigma
dt = sigma_next - sigma
prev_sample = sample + derivative * dt
else:
if self.state_in_first_order:
t_next = t_proposed
else:
sample = self.sample
sigma_from = sigma_fn(t)
sigma_to = sigma_fn(t_next)
sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
ancestral_t = t_fn(sigma_down)
prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - (
t - ancestral_t
).expm1() * pred_original_sample
prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up
if self.state_in_first_order:
# store for 2nd order step
self.sample = sample
self.mid_point_sigma = sigma_fn(t_next)
else:
# free for "first order mode"
self.sample = None
self.mid_point_sigma = None
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -70,8 +70,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
`linear` or `scaled_linear`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4

View File

@@ -43,6 +43,7 @@ class KarrasDiffusionSchedulers(Enum):
KDPM2AncestralDiscreteScheduler = 11
DEISMultistepScheduler = 12
UniPCMultistepScheduler = 13
DPMSolverSDEScheduler = 14
@dataclass

View File

@@ -70,6 +70,7 @@ from .import_utils import (
is_tf_available,
is_torch_available,
is_torch_version,
is_torchsde_available,
is_transformers_available,
is_transformers_version,
is_unidecode_available,

View File

@@ -0,0 +1,17 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
class DPMSolverSDEScheduler(metaclass=DummyObject):
_backends = ["torch", "torchsde"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "torchsde"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "torchsde"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "torchsde"])

View File

@@ -287,6 +287,13 @@ try:
except importlib_metadata.PackageNotFoundError:
_bs4_available = False
_torchsde_available = importlib.util.find_spec("torchsde") is not None
try:
_torchsde_version = importlib_metadata.version("torchsde")
logger.debug(f"Successfully imported torchsde version {_torchsde_version}")
except importlib_metadata.PackageNotFoundError:
_torchsde_available = False
def is_torch_available():
return _torch_available
@@ -372,6 +379,10 @@ def is_bs4_available():
return _bs4_available
def is_torchsde_available():
return _torchsde_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -475,6 +486,11 @@ installation section: https://github.com/rspeer/python-ftfy/tree/master#installi
that match your environment. Please note that you may need to restart your runtime after installation.
"""
# docstyle-ignore
TORCHSDE_IMPORT_ERROR = """
{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde`
"""
BACKENDS_MAPPING = OrderedDict(
[
@@ -495,6 +511,7 @@ BACKENDS_MAPPING = OrderedDict(
("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)),
("compel", (_compel_available, COMPEL_IMPORT_ERROR)),
("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
("torchsde", (_torchsde_available, TORCHSDE_IMPORT_ERROR)),
]
)

View File

@@ -26,6 +26,7 @@ from .import_utils import (
is_opencv_available,
is_torch_available,
is_torch_version,
is_torchsde_available,
)
from .logging import get_logger
@@ -216,6 +217,13 @@ def require_note_seq(test_case):
return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
def require_torchsde(test_case):
"""
Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
"""
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
if isinstance(arry, str):
# local_path = "/home/patrick_huggingface_co/"

View File

@@ -0,0 +1,156 @@
import torch
from diffusers import DPMSolverSDEScheduler
from diffusers.utils import torch_device
from diffusers.utils.testing_utils import require_torchsde
from .test_schedulers import SchedulerCommonTest
@require_torchsde
class DPMSolverSDESchedulerTest(SchedulerCommonTest):
scheduler_classes = (DPMSolverSDEScheduler,)
num_inference_steps = 10
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1100,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"noise_sampler_seed": 0,
}
config.update(**kwargs)
return config
def test_timesteps(self):
for timesteps in [10, 50, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "scaled_linear"]:
self.check_over_configs(beta_schedule=schedule)
def test_prediction_type(self):
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
for i, t in enumerate(scheduler.timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
if torch_device in ["mps"]:
assert abs(result_sum.item() - 167.47821044921875) < 1e-2
assert abs(result_mean.item() - 0.2178705964565277) < 1e-3
else:
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
def test_full_loop_with_v_prediction(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(prediction_type="v_prediction")
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps)
model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.init_noise_sigma
sample = sample.to(torch_device)
for i, t in enumerate(scheduler.timesteps):
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
if torch_device in ["mps"]:
assert abs(result_sum.item() - 124.77149200439453) < 1e-2
assert abs(result_mean.item() - 0.16226289014816284) < 1e-3
else:
assert abs(result_sum.item() - 119.8487548828125) < 1e-2
assert abs(result_mean.item() - 0.1560530662536621) < 1e-3
def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
model = self.dummy_model()
sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
if torch_device in ["mps"]:
assert abs(result_sum.item() - 167.46957397460938) < 1e-2
assert abs(result_mean.item() - 0.21805934607982635) < 1e-3
else:
assert abs(result_sum.item() - 162.52383422851562) < 1e-2
assert abs(result_mean.item() - 0.211619570851326) < 1e-3
def test_full_loop_device_karras_sigmas(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True)
scheduler.set_timesteps(self.num_inference_steps, device=torch_device)
model = self.dummy_model()
sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma
sample = sample.to(torch_device)
for t in scheduler.timesteps:
sample = scheduler.scale_model_input(sample, t)
model_output = model(sample, t)
output = scheduler.step(model_output, t, sample)
sample = output.prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
if torch_device in ["mps"]:
assert abs(result_sum.item() - 176.66974135742188) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2
else:
assert abs(result_sum.item() - 170.3135223388672) < 1e-2
assert abs(result_mean.item() - 0.23003872730981811) < 1e-2