1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Unify offset configuration in DDIM and PNDM schedulers (#479)

* Unify offset configuration in DDIM and PNDM schedulers

* Format

Add missing variables

* Fix pipeline test

* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Default set_alpha_to_one to false

* Format

* Add tests

* Format

* add deprecation warning

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Jonatan Kłosko
2022-09-17 14:07:43 +02:00
committed by GitHub
parent 9e439d8c60
commit d7dcba4a13
7 changed files with 219 additions and 83 deletions

View File

@@ -6,6 +6,7 @@ import torch
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -53,6 +54,21 @@ class StableDiffusionPipeline(DiffusionPipeline):
):
super().__init__()
scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -217,12 +233,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
latents = latents.to(self.device)
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
if accepts_offset:
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
self.scheduler.set_timesteps(num_inference_steps)
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):

View File

@@ -1,4 +1,5 @@
import inspect
import warnings
from typing import List, Optional, Union
import numpy as np
@@ -7,6 +8,7 @@ import torch
import PIL
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -64,6 +66,21 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
):
super().__init__()
scheduler = scheduler.set_format("pt")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -169,14 +186,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
offset = 0
if accepts_offset:
offset = 1
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
self.scheduler.set_timesteps(num_inference_steps)
if isinstance(init_image, PIL.Image.Image):
init_image = preprocess(init_image)
@@ -190,6 +200,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_latents = torch.cat([init_latents] * batch_size)
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
if isinstance(self.scheduler, LMSDiscreteScheduler):

View File

@@ -1,4 +1,5 @@
import inspect
import warnings
from typing import List, Optional, Union
import numpy as np
@@ -8,6 +9,7 @@ import PIL
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
@@ -83,6 +85,21 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
super().__init__()
scheduler = scheduler.set_format("pt")
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
warnings.warn(
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 istead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file",
DeprecationWarning,
)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -193,19 +210,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
offset = 0
if accepts_offset:
offset = 1
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
self.scheduler.set_timesteps(num_inference_steps)
# preprocess image
if not isinstance(init_image, torch.FloatTensor):
init_image = preprocess_image(init_image)
init_image.to(self.device)
init_image = init_image.to(self.device)
# encode the init image into latents and scale the latents
init_latent_dist = self.vae.encode(init_image).latent_dist
@@ -220,7 +230,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# preprocess mask
if not isinstance(mask_image, torch.FloatTensor):
mask_image = preprocess_mask(mask_image)
mask_image.to(self.device)
mask_image = mask_image.to(self.device)
mask = torch.cat([mask_image] * batch_size)
# check sizes
@@ -228,6 +238,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
raise ValueError("The mask and init_image should be the same size!")
# get the original timestep using init_timestep
offset = self.scheduler.config.get("steps_offset", 0)
init_timestep = int(num_inference_steps * strength) + offset
init_timestep = min(init_timestep, num_inference_steps)
timesteps = self.scheduler.timesteps[-init_timestep]

View File

@@ -100,12 +100,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
extra_set_kwargs = {}
if accepts_offset:
extra_set_kwargs["offset"] = 1
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
self.scheduler.set_timesteps(num_inference_steps)
# if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
if isinstance(self.scheduler, LMSDiscreteScheduler):

View File

@@ -16,6 +16,7 @@
# and https://github.com/hojonathanho/diffusion
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
@@ -78,7 +79,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
set_alpha_to_one (`bool`, default `True`):
if alpha for final step is 1 or the final alpha of the "non-previous" one.
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
"""
@@ -93,6 +100,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
tensor_format: str = "pt",
):
if trained_betas is not None:
@@ -134,16 +142,26 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance
def set_timesteps(self, num_inference_steps: int, offset: int = 0):
def set_timesteps(self, num_inference_steps: int, **kwargs):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
offset = self.config.steps_offset
if "offset" in kwargs:
warnings.warn(
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
" Please pass `steps_offset` to `__init__` instead.",
DeprecationWarning,
)
offset = kwargs["offset"]
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio

View File

@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import warnings
from typing import Optional, Tuple, Union
import numpy as np
@@ -74,10 +75,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
skip_prk_steps (`bool`):
allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
before plms steps; defaults to `False`.
set_alpha_to_one (`bool`, default `False`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the value of alpha at step 0.
steps_offset (`int`, default `0`):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
"""
@@ -89,8 +98,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
tensor_format: str = "pt",
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
tensor_format: str = "pt",
):
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
@@ -108,6 +119,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -122,7 +135,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self._offset = 0
self.prk_timesteps = None
self.plms_timesteps = None
self.timesteps = None
@@ -130,23 +142,31 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
offset (`int`):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
offset = self.config.steps_offset
if "offset" in kwargs:
warnings.warn(
"`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
" Please pass `steps_offset` to `__init__` instead."
)
offset = kwargs["offset"]
self.num_inference_steps = num_inference_steps
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist()
self._offset = offset
self._timesteps = np.array([t + self._offset for t in self._timesteps])
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
self._timesteps += offset
if self.config.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
@@ -231,7 +251,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
)
diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
prev_timestep = timestep - diff_to_prev
timestep = self.prk_timesteps[self.counter // 4 * 4]
if self.counter % 4 == 0:
@@ -293,7 +313,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
)
prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
if self.counter != 1:
self.ets.append(model_output)
@@ -323,7 +343,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
# this function computes x_(tδ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
@@ -336,8 +356,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(tδ)
alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

View File

@@ -357,10 +357,38 @@ class DDIMSchedulerTest(SchedulerCommonTest):
config.update(**kwargs)
return config
def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps, eta = 10, 0.0
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
for t in scheduler.timesteps:
residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta).prev_sample
return sample
def test_timesteps(self):
for timesteps in [100, 500, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_steps_offset(self):
for steps_offset in [0, 1]:
self.check_over_configs(steps_offset=steps_offset)
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(steps_offset=1)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(5)
assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1]))
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
@@ -398,20 +426,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps, eta = 10, 0.0
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
for t in scheduler.timesteps:
residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta).prev_sample
sample = self.full_loop()
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
@@ -419,6 +434,24 @@ class DDIMSchedulerTest(SchedulerCommonTest):
assert abs(result_sum.item() - 172.0067) < 1e-2
assert abs(result_mean.item() - 0.223967) < 1e-3
def test_full_loop_with_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 149.8295) < 1e-2
assert abs(result_mean.item() - 0.1951) < 1e-3
def test_full_loop_with_no_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99
sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 149.0784) < 1e-2
assert abs(result_mean.item() - 0.1941) < 1e-3
class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (PNDMScheduler,)
@@ -503,6 +536,26 @@ class PNDMSchedulerTest(SchedulerCommonTest):
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def full_loop(self, **config):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(scheduler.prk_timesteps):
residual = model(sample, t)
sample = scheduler.step_prk(residual, t, sample).prev_sample
for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t)
sample = scheduler.step_plms(residual, t, sample).prev_sample
return sample
def test_pytorch_equal_numpy(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -606,8 +659,23 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for timesteps in [100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_steps_offset(self):
for steps_offset in [0, 1]:
self.check_over_configs(steps_offset=steps_offset)
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(steps_offset=1)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(10)
assert torch.equal(
scheduler.timesteps,
torch.tensor(
[901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
),
)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
@@ -620,7 +688,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
self.check_over_forward(num_inference_steps=num_inference_steps)
def test_pow_of_3_inference_steps(self):
# earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
@@ -648,28 +716,30 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(scheduler.prk_timesteps):
residual = model(sample, t)
sample = scheduler.step_prk(residual, i, sample).prev_sample
for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t)
sample = scheduler.step_plms(residual, i, sample).prev_sample
sample = self.full_loop()
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 428.8788) < 1e-2
assert abs(result_mean.item() - 0.5584) < 1e-3
assert abs(result_sum.item() - 198.1318) < 1e-2
assert abs(result_mean.item() - 0.2580) < 1e-3
def test_full_loop_with_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99
sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 230.0399) < 1e-2
assert abs(result_mean.item() - 0.2995) < 1e-3
def test_full_loop_with_no_set_alpha_to_one(self):
# We specify different beta, so that the first alpha is 0.99
sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 186.9482) < 1e-2
assert abs(result_mean.item() - 0.2434) < 1e-3
class ScoreSdeVeSchedulerTest(unittest.TestCase):