Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Unify offset configuration in DDIM and PNDM schedulers (#479)
* Unify offset configuration in DDIM and PNDM schedulers
* Format
* Add missing variables
* Fix pipeline test
* Update src/diffusers/schedulers/scheduling_ddim.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Default set_alpha_to_one to false
* Format
* Add tests
* Format
* Add deprecation warning

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -6,6 +6,7 @@ import torch

from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -53,6 +54,21 @@ class StableDiffusionPipeline(DiffusionPipeline):
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            warnings.warn(
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull Request for the `scheduler/scheduler_config.json`"
                " file",
                DeprecationWarning,
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
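Illustration (not from the diff): a minimal sketch of how a user avoids the deprecation path added above by constructing the scheduler with the new `steps_offset` argument directly. The beta values below are the usual Stable Diffusion-style settings and are assumptions for illustration only; only `steps_offset` and `set_alpha_to_one` come from this PR.

from diffusers import DDIMScheduler

scheduler = DDIMScheduler(
    beta_start=0.00085,       # assumed Stable Diffusion-style betas, illustration only
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,   # final step falls back to alphas_cumprod[0]
    steps_offset=1,           # new config value from this PR; satisfies the check above
)

With `steps_offset=1` already in the config, the `hasattr` check above passes and `scheduler._internal_dict` is left untouched, so no `DeprecationWarning` is emitted.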
@@ -217,12 +233,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
        latents = latents.to(self.device)

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        self.scheduler.set_timesteps(num_inference_steps)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
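Illustration (not from the diff): the hunk above removes the signature-probing variant of `set_timesteps` (the call with `**extra_set_kwargs`) in favor of the bare call, since the offset now lives in the scheduler config. The helper below is a hypothetical, minimal sketch of the feature-detection pattern that is being retired, kept only to show what the deleted lines did.

import inspect

def call_with_supported_kwargs(fn, *args, **candidate_kwargs):
    # Forward only the keyword arguments that fn actually declares in its signature.
    accepted = set(inspect.signature(fn).parameters)
    kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted}
    return fn(*args, **kwargs)

# e.g. call_with_supported_kwargs(scheduler.set_timesteps, 50, offset=1) would pass
# offset=1 only to schedulers whose set_timesteps still declares an `offset` parameter.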
@@ -1,4 +1,5 @@
import inspect
import warnings
from typing import List, Optional, Union

import numpy as np
@@ -7,6 +8,7 @@ import torch
import PIL
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
@@ -64,6 +66,21 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
    ):
        super().__init__()
        scheduler = scheduler.set_format("pt")

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            warnings.warn(
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull Request for the `scheduler/scheduler_config.json`"
                " file",
                DeprecationWarning,
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
@@ -169,14 +186,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        self.scheduler.set_timesteps(num_inference_steps)

        if isinstance(init_image, PIL.Image.Image):
            init_image = preprocess(init_image)
@@ -190,6 +200,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        init_latents = torch.cat([init_latents] * batch_size)

        # get the original timestep using init_timestep
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        if isinstance(self.scheduler, LMSDiscreteScheduler):
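Illustration (not from the diff): a quick worked example of the `init_timestep` computation above, with assumed values.

# Assumed values, for illustration only.
num_inference_steps, strength, offset = 50, 0.75, 1
init_timestep = int(num_inference_steps * strength) + offset   # int(37.5) + 1 = 38
init_timestep = min(init_timestep, num_inference_steps)        # still 38; min() only caps strength close to 1.0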
@@ -1,4 +1,5 @@
import inspect
import warnings
from typing import List, Optional, Union

import numpy as np
@@ -8,6 +9,7 @@ import PIL
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
@@ -83,6 +85,21 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        super().__init__()
        scheduler = scheduler.set_format("pt")
        logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            warnings.warn(
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull Request for the `scheduler/scheduler_config.json`"
                " file",
                DeprecationWarning,
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
@@ -193,19 +210,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        offset = 0
        if accepts_offset:
            offset = 1
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        self.scheduler.set_timesteps(num_inference_steps)

        # preprocess image
        if not isinstance(init_image, torch.FloatTensor):
            init_image = preprocess_image(init_image)
        init_image.to(self.device)
        init_image = init_image.to(self.device)

        # encode the init image into latents and scale the latents
        init_latent_dist = self.vae.encode(init_image).latent_dist
@@ -220,7 +230,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        # preprocess mask
        if not isinstance(mask_image, torch.FloatTensor):
            mask_image = preprocess_mask(mask_image)
        mask_image.to(self.device)
        mask_image = mask_image.to(self.device)
        mask = torch.cat([mask_image] * batch_size)

        # check sizes
@@ -228,6 +238,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
            raise ValueError("The mask and init_image should be the same size!")

        # get the original timestep using init_timestep
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * strength) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        timesteps = self.scheduler.timesteps[-init_timestep]
@@ -100,12 +100,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        # set timesteps
        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
        extra_set_kwargs = {}
        if accepts_offset:
            extra_set_kwargs["offset"] = 1

        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
        self.scheduler.set_timesteps(num_inference_steps)

        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
        if isinstance(self.scheduler, LMSDiscreteScheduler):
@@ -16,6 +16,7 @@
# and https://github.com/hojonathanho/diffusion

import math
import warnings
from typing import Optional, Tuple, Union

import numpy as np
@@ -78,7 +79,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between -1 and 1 for numerical stability.
        set_alpha_to_one (`bool`, default `True`):
            if alpha for final step is 1 or the final alpha of the "non-previous" one.
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.

    """
@@ -93,6 +100,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        trained_betas: Optional[np.ndarray] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        tensor_format: str = "pt",
    ):
        if trained_betas is not None:
@@ -134,16 +142,26 @@

        return variance

    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
    def set_timesteps(self, num_inference_steps: int, **kwargs):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            offset (`int`):
                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
        """

        offset = self.config.steps_offset

        if "offset" in kwargs:
            warnings.warn(
                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
                " Please pass `steps_offset` to `__init__` instead.",
                DeprecationWarning,
            )

            offset = kwargs["offset"]

        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
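Illustration (not from the diff): a short usage sketch of the migration this hunk implements. Scheduler arguments other than `steps_offset` are left at their defaults and are assumed here only for illustration.

from diffusers import DDIMScheduler

# New style: configure the offset once at construction time.
scheduler = DDIMScheduler(steps_offset=1)
scheduler.set_timesteps(50)

# Old style, now deprecated: passing offset per call still works until v0.4.0,
# but emits the DeprecationWarning added above and overrides the configured value.
scheduler.set_timesteps(50, offset=1)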
@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
import warnings
from typing import Optional, Tuple, Union

import numpy as np
@@ -74,10 +75,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, optional):
            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays
        skip_prk_steps (`bool`):
            allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
            before plms steps; defaults to `False`.
        set_alpha_to_one (`bool`, default `False`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays

    """
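Illustration (not from the diff): a minimal construction sketch using the new options documented above. The beta values and `skip_prk_steps=True` are the usual Stable Diffusion-style settings and are assumptions for illustration only.

from diffusers import PNDMScheduler

scheduler = PNDMScheduler(
    beta_start=0.00085,          # assumed values, for illustration
    beta_end=0.012,
    beta_schedule="scaled_linear",
    skip_prk_steps=True,         # skip the Runge-Kutta warm-up steps, as in Stable Diffusion
    set_alpha_to_one=False,      # final step falls back to alphas_cumprod[0]
    steps_offset=1,              # new in this PR: shift inference timesteps up by one
)
scheduler.set_timesteps(50)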
@@ -89,8 +98,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[np.ndarray] = None,
        tensor_format: str = "pt",
        skip_prk_steps: bool = False,
        set_alpha_to_one: bool = False,
        steps_offset: int = 0,
        tensor_format: str = "pt",
    ):
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
@@ -108,6 +119,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

        self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # For now we only support F-PNDM, i.e. the runge-kutta method
        # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
        # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -122,7 +135,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        # setable values
        self.num_inference_steps = None
        self._timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
        self._offset = 0
        self.prk_timesteps = None
        self.plms_timesteps = None
        self.timesteps = None
@@ -130,23 +142,31 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        self.tensor_format = tensor_format
        self.set_format(tensor_format=tensor_format)

    def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
    def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            offset (`int`):
                optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
        """

        offset = self.config.steps_offset

        if "offset" in kwargs:
            warnings.warn(
                "`offset` is deprecated as an input argument to `set_timesteps` and will be removed in v0.4.0."
                " Please pass `steps_offset` to `__init__` instead."
            )

            offset = kwargs["offset"]

        self.num_inference_steps = num_inference_steps
        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist()
        self._offset = offset
        self._timesteps = np.array([t + self._offset for t in self._timesteps])
        self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()
        self._timesteps += offset

        if self.config.skip_prk_steps:
            # for some models like stable diffusion the prk steps can/should be skipped to
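Illustration (not from the diff): a quick worked example of the new timestep computation above, with assumed values `num_train_timesteps=1000`, `num_inference_steps=5`, `steps_offset=1`.

import numpy as np

num_train_timesteps, num_inference_steps, steps_offset = 1000, 5, 1   # assumed values
step_ratio = num_train_timesteps // num_inference_steps               # 200
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round() + steps_offset
print(timesteps)   # [  1 201 401 601 801]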
@@ -231,7 +251,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
            )

        diff_to_prev = 0 if self.counter % 2 else self.config.num_train_timesteps // self.num_inference_steps // 2
        prev_timestep = max(timestep - diff_to_prev, self.prk_timesteps[-1])
        prev_timestep = timestep - diff_to_prev
        timestep = self.prk_timesteps[self.counter // 4 * 4]

        if self.counter % 4 == 0:
@@ -293,7 +313,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                "for more information."
            )

        prev_timestep = max(timestep - self.config.num_train_timesteps // self.num_inference_steps, 0)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        if self.counter != 1:
            self.ets.append(model_output)
@@ -323,7 +343,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

        return SchedulerOutput(prev_sample=prev_sample)

    def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
        # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
        # this function computes x_(t−δ) using the formula of (9)
        # Note that x_t needs to be added to both sides of the equation
@@ -336,8 +356,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
        # sample -> x_t
        # model_output -> e_θ(x_t, t)
        # prev_sample -> x_(t−δ)
        alpha_prod_t = self.alphas_cumprod[timestep + 1 - self._offset]
        alpha_prod_t_prev = self.alphas_cumprod[timestep_prev + 1 - self._offset]
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
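Illustration (not from the diff): a minimal sketch of the new indexing rule above, using an assumed toy beta schedule. The alpha products are now indexed directly by the (already offset) timestep, and the final step, whose previous timestep is negative, falls back to `final_alpha_cumprod` (`alphas_cumprod[0]` when `set_alpha_to_one=False`).

import numpy as np

# Assumed toy schedule, for illustration only.
alphas_cumprod = np.cumprod(1.0 - np.linspace(0.0001, 0.02, 1000))
final_alpha_cumprod = alphas_cumprod[0]                                  # set_alpha_to_one=False

num_train_timesteps, num_inference_steps = 1000, 50
timestep = 1                                                             # last timestep when steps_offset=1
prev_timestep = timestep - num_train_timesteps // num_inference_steps   # -19, i.e. "before" step 0

alpha_prod_t = alphas_cumprod[timestep]
alpha_prod_t_prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod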
@@ -357,10 +357,38 @@ class DDIMSchedulerTest(SchedulerCommonTest):
        config.update(**kwargs)
        return config

    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps, eta = 10, 0.0

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        scheduler.set_timesteps(num_inference_steps)

        for t in scheduler.timesteps:
            residual = model(sample, t)
            sample = scheduler.step(residual, t, sample, eta).prev_sample

        return sample

    def test_timesteps(self):
        for timesteps in [100, 500, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_timesteps(5)
        assert torch.equal(scheduler.timesteps, torch.tensor([801, 601, 401, 201, 1]))

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
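Illustration (not from the diff): how the expected tensor in `test_steps_offset` comes about, assuming the internal computation mirrors the scheduler change above. With 1000 training timesteps and 5 inference steps the step ratio is 200, giving descending timesteps [800, 600, 400, 200, 0], and `steps_offset=1` shifts each by one.

import numpy as np

step_ratio = 1000 // 5                              # 200
expected = (np.arange(5) * step_ratio)[::-1] + 1    # apply steps_offset=1
print(expected)                                     # [801 601 401 201   1]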
@@ -398,20 +426,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
        assert torch.sum(torch.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps, eta = 10, 0.0

        model = self.dummy_model()
        sample = self.dummy_sample_deter

        scheduler.set_timesteps(num_inference_steps)
        for t in scheduler.timesteps:
            residual = model(sample, t)

            sample = scheduler.step(residual, t, sample, eta).prev_sample
        sample = self.full_loop()

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))
@@ -419,6 +434,24 @@ class DDIMSchedulerTest(SchedulerCommonTest):
        assert abs(result_sum.item() - 172.0067) < 1e-2
        assert abs(result_mean.item() - 0.223967) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 149.8295) < 1e-2
        assert abs(result_mean.item() - 0.1951) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 149.0784) < 1e-2
        assert abs(result_mean.item() - 0.1941) < 1e-3


class PNDMSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (PNDMScheduler,)
@@ -503,6 +536,26 @@ class PNDMSchedulerTest(SchedulerCommonTest):

        assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

    def full_loop(self, **config):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(**config)
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.prk_timesteps):
            residual = model(sample, t)
            sample = scheduler.step_prk(residual, t, sample).prev_sample

        for i, t in enumerate(scheduler.plms_timesteps):
            residual = model(sample, t)
            sample = scheduler.step_plms(residual, t, sample).prev_sample

        return sample

    def test_pytorch_equal_numpy(self):
        kwargs = dict(self.forward_default_kwargs)
        num_inference_steps = kwargs.pop("num_inference_steps", None)
@@ -606,8 +659,23 @@ class PNDMSchedulerTest(SchedulerCommonTest):
        for timesteps in [100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_steps_offset(self):
        for steps_offset in [0, 1]:
            self.check_over_configs(steps_offset=steps_offset)

        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(steps_offset=1)
        scheduler = scheduler_class(**scheduler_config)
        scheduler.set_timesteps(10)
        assert torch.equal(
            scheduler.timesteps,
            torch.tensor(
                [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
            ),
        )

    def test_betas(self):
        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
        for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):
            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

    def test_schedules(self):
@@ -620,7 +688,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):

    def test_inference_steps(self):
        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
            self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
            self.check_over_forward(num_inference_steps=num_inference_steps)

    def test_pow_of_3_inference_steps(self):
        # earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
@@ -648,28 +716,30 @@ class PNDMSchedulerTest(SchedulerCommonTest):
        scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample

    def test_full_loop_no_noise(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        num_inference_steps = 10
        model = self.dummy_model()
        sample = self.dummy_sample_deter
        scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(scheduler.prk_timesteps):
            residual = model(sample, t)
            sample = scheduler.step_prk(residual, i, sample).prev_sample

        for i, t in enumerate(scheduler.plms_timesteps):
            residual = model(sample, t)
            sample = scheduler.step_plms(residual, i, sample).prev_sample

        sample = self.full_loop()
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 428.8788) < 1e-2
        assert abs(result_mean.item() - 0.5584) < 1e-3
        assert abs(result_sum.item() - 198.1318) < 1e-2
        assert abs(result_mean.item() - 0.2580) < 1e-3

    def test_full_loop_with_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=True, beta_start=0.01)
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 230.0399) < 1e-2
        assert abs(result_mean.item() - 0.2995) < 1e-3

    def test_full_loop_with_no_set_alpha_to_one(self):
        # We specify different beta, so that the first alpha is 0.99
        sample = self.full_loop(set_alpha_to_one=False, beta_start=0.01)
        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 186.9482) < 1e-2
        assert abs(result_mean.item() - 0.2434) < 1e-3


class ScoreSdeVeSchedulerTest(unittest.TestCase):