1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Finish scheduler API (#91)

* finish

* up
This commit is contained in:
Patrick von Platen
2022-07-15 15:04:01 +02:00
committed by GitHub
parent 97e1e3ba76
commit f448360bd0
14 changed files with 233 additions and 188 deletions

View File

@@ -438,8 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
def forward(self, sample, step_value, transformer_out=None):
timesteps = step_value
def forward(self, sample, timestep, transformer_out=None):
timesteps = timestep
x = sample
hs = []
emb = self.time_embed(
@@ -530,8 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
resblock_updown=resblock_updown,
)
def forward(self, sample, step_value, low_res=None):
timesteps = step_value
def forward(self, sample, timestep, low_res=None):
timesteps = timestep
x = sample
_, _, new_height, new_width = x.shape
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")

View File

@@ -323,8 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
self.all_modules = nn.ModuleList(modules)
def forward(self, sample, step_value, sigmas=None):
timesteps = step_value
def forward(self, sample, timestep, sigmas=None):
timesteps = timestep
x = sample
# timestep/noise_level embedding; only for continuous training
modules = self.all_modules

View File

@@ -254,7 +254,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
# ======================================
def forward(
self, sample: torch.FloatTensor, step_value: Union[torch.Tensor, float, int]
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
) -> Dict[str, torch.FloatTensor]:
# TODO(PVP) - to delete later at release
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
@@ -263,10 +263,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
self.set_weights()
# ======================================
# 1. time step embeddings
timesteps = step_value
# 1. time step embeddings -> make correct tensor
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
t_emb = get_timestep_embedding(
timesteps,

View File

@@ -22,19 +22,16 @@ from ...pipeline_utils import DiffusionPipeline
class DDIMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
def __init__(self, unet, scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
num_trained_timesteps = self.noise_scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device)
# Sample gaussian noise to begin loop
@@ -44,34 +41,19 @@ class DDIMPipeline(DiffusionPipeline):
)
image = image.to(torch_device)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# set step values
self.scheduler.set_timesteps(num_inference_steps)
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
for t in tqdm.tqdm(self.scheduler.timesteps):
# 1. predict noise residual
with torch.no_grad():
residual = self.unet(image, inference_step_times[t])
residual = self.unet(image, t)
if isinstance(residual, dict):
residual = residual["sample"]
# 2. predict previous mean of image x_t-1
pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(residual, t, image, eta)["prev_sample"]
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
return image
return {"sample": image}

View File

@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
class DDPMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
def __init__(self, unet, scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None):
if torch_device is None:
@@ -40,7 +40,7 @@ class DDPMPipeline(DiffusionPipeline):
)
image = image.to(torch_device)
num_prediction_steps = len(self.noise_scheduler)
num_prediction_steps = len(self.scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
# 1. predict noise residual
with torch.no_grad():
@@ -50,13 +50,13 @@ class DDPMPipeline(DiffusionPipeline):
residual = residual["sample"]
# 2. predict previous mean of image x_t-1
pred_prev_image = self.noise_scheduler.step(residual, image, t)
pred_prev_image = self.scheduler.step(residual, t, image)["prev_sample"]
# 3. optionally sample variance
variance = 0
if t > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.noise_scheduler.get_variance(t).sqrt() * noise
variance = self.scheduler.get_variance(t).sqrt() * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance

View File

@@ -713,20 +713,20 @@ class GlidePipeline(DiffusionPipeline):
def __init__(
self,
text_unet: GlideTextToImageUNetModel,
text_noise_scheduler: DDPMScheduler,
text_scheduler: DDPMScheduler,
text_encoder: CLIPTextModel,
tokenizer: GPT2Tokenizer,
upscale_unet: GlideSuperResUNetModel,
upscale_noise_scheduler: DDIMScheduler,
upscale_scheduler: DDIMScheduler,
):
super().__init__()
self.register_modules(
text_unet=text_unet,
text_noise_scheduler=text_noise_scheduler,
text_scheduler=text_scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
upscale_unet=upscale_unet,
upscale_noise_scheduler=upscale_noise_scheduler,
upscale_scheduler=upscale_scheduler,
)
@torch.no_grad()
@@ -777,20 +777,20 @@ class GlidePipeline(DiffusionPipeline):
transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
# 3. Run the text2image generation step
num_prediction_steps = len(self.text_noise_scheduler)
num_prediction_steps = len(self.text_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
with torch.no_grad():
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
model_output = text_model_fn(image, time_input, transformer_out)
noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
min_log = self.text_scheduler.get_variance(t, "fixed_small_log")
max_log = self.text_scheduler.get_variance(t, "fixed_large_log")
# The model_var_values is [-1, 1] for [min_var, max_var].
frac = (model_var_values + 1) / 2
model_log_variance = frac * max_log + (1 - frac) * min_log
pred_prev_image = self.text_noise_scheduler.step(noise_residual, image, t)
pred_prev_image = self.text_scheduler.step(noise_residual, image, t)
noise = torch.randn(image.shape, generator=generator).to(torch_device)
variance = torch.exp(0.5 * model_log_variance) * noise
@@ -814,7 +814,7 @@ class GlidePipeline(DiffusionPipeline):
).to(torch_device)
image = image * upsample_temp
num_trained_timesteps = self.upscale_noise_scheduler.timesteps
num_trained_timesteps = self.upscale_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
@@ -825,7 +825,7 @@ class GlidePipeline(DiffusionPipeline):
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
# 2. predict previous mean of image x_t-1
pred_prev_image = self.upscale_noise_scheduler.step(
pred_prev_image = self.upscale_scheduler.step(
noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
)
@@ -833,9 +833,7 @@ class GlidePipeline(DiffusionPipeline):
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(torch_device)
variance = (
self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
)
variance = self.upscale_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance

View File

@@ -545,10 +545,10 @@ class LDMBertModel(LDMBertPreTrainedModel):
class LatentDiffusionPipeline(DiffusionPipeline):
def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
@@ -581,7 +581,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)
num_trained_timesteps = self.noise_scheduler.config.timesteps
num_trained_timesteps = self.scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = torch.randn(
@@ -622,13 +622,13 @@ class LatentDiffusionPipeline(DiffusionPipeline):
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
# 2. predict previous mean of image x_t-1
pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance

View File

@@ -6,10 +6,10 @@ from ...pipeline_utils import DiffusionPipeline
class LatentDiffusionUncondPipeline(DiffusionPipeline):
def __init__(self, vqvae, unet, noise_scheduler):
def __init__(self, vqvae, unet, scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, unet=unet, noise_scheduler=noise_scheduler)
scheduler = scheduler.set_format("pt")
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
@@ -28,44 +28,23 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self.unet.to(torch_device)
self.vqvae.to(torch_device)
num_trained_timesteps = self.noise_scheduler.config.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator,
).to(torch_device)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
self.scheduler.set_timesteps(num_inference_steps)
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointingc to x_t"
# - pred_prev_image -> "x_t-1"
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# 1. predict noise residual
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
pred_noise_t = self.unet(image, timesteps)
for t in tqdm.tqdm(self.scheduler.timesteps):
residual = self.unet(image, t)
if isinstance(pred_noise_t, dict):
pred_noise_t = pred_noise_t["sample"]
if isinstance(residual, dict):
residual = residual["sample"]
# 2. predict previous mean of image x_t-1
pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(residual, t, image, eta)["prev_sample"]
# decode image with vae
image = self.vqvae.decode(image)
return image
return {"sample": image}

View File

@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
class PNDMPipeline(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
def __init__(self, unet, scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
scheduler = scheduler.set_format("pt")
self.register_modules(unet=unet, scheduler=scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
# For more information on the sampling method you can take a look at Algorithm 2 of
@@ -42,7 +42,7 @@ class PNDMPipeline(DiffusionPipeline):
)
image = image.to(torch_device)
prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps)
prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(prk_time_steps))):
t_orig = prk_time_steps[t]
residual = self.unet(image, t_orig)
@@ -50,9 +50,9 @@ class PNDMPipeline(DiffusionPipeline):
if isinstance(residual, dict):
residual = residual["sample"]
image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
image = self.scheduler.step_prk(residual, t, image, num_inference_steps)["prev_sample"]
timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
timesteps = self.scheduler.get_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(timesteps))):
t_orig = timesteps[t]
residual = self.unet(image, t_orig)
@@ -60,6 +60,6 @@ class PNDMPipeline(DiffusionPipeline):
if isinstance(residual, dict):
residual = residual["sample"]
image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)
image = self.scheduler.step_plms(residual, t, image, num_inference_steps)["prev_sample"]
return image

View File

@@ -16,8 +16,10 @@
# and https://github.com/hojonathanho/diffusion
import math
from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
@@ -84,14 +86,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.one = np.array(1.0)
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, self.config.timesteps)[::-1].copy()
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def get_variance(self, t, num_inference_steps):
orig_t = self.config.timesteps // num_inference_steps * t
orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
alpha_prod_t = self.alphas_cumprod[orig_t]
alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -99,7 +103,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance
def step(self, residual, sample, t, num_inference_steps, eta, use_clipped_residual=False):
def set_timesteps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.config.timesteps, self.config.timesteps // self.num_inference_steps)[
::-1
].copy()
self.set_format(tensor_format=self.tensor_format)
def step(
self,
residual: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
eta,
use_clipped_residual=False,
generator=None,
):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
@@ -111,13 +130,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# - pred_sample_direction -> "direction pointingc to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get actual t and t-1
orig_t = self.config.timesteps // num_inference_steps * t
orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.timesteps // self.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[orig_t]
alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
@@ -130,7 +148,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
variance = self.get_variance(t, num_inference_steps)
variance = self._get_variance(timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
if use_clipped_residual:
@@ -141,9 +159,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
return pred_prev_sample
if eta > 0:
device = residual.device if torch.is_tensor(residual) else "cpu"
noise = torch.randn(residual.shape, generator=generator).to(device)
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
if not torch.is_tensor(residual):
variance = variance.numpy()
prev_sample = prev_sample + variance
return {"prev_sample": prev_sample}
def add_noise(self, original_samples, noise, timesteps):
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5

View File

@@ -15,8 +15,10 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
@@ -112,7 +114,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return variance
def step(self, residual, sample, t, predict_epsilon=True):
def step(
self,
residual: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
predict_epsilon=True,
):
t = timestep
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
@@ -139,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
return pred_prev_sample
return {"prev_sample": pred_prev_sample}
def add_noise(self, original_samples, noise, timesteps):
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5

View File

@@ -15,8 +15,10 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from typing import Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
@@ -126,7 +128,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
raise ValueError(f"mode {self.mode} does not exist.")
def step_prk(self, residual, sample, t, num_inference_steps):
def step_prk(
self,
residual: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
):
t = timestep
prk_time_steps = self.get_prk_time_steps(num_inference_steps)
t_orig = prk_time_steps[t // 4 * 4]
@@ -147,9 +156,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
# cur_sample should not be `None`
cur_sample = self.cur_sample if self.cur_sample is not None else sample
return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)
return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)}
def step_plms(self, residual, sample, t, num_inference_steps):
def step_plms(
self,
residual: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
):
t = timestep
if len(self.ets) < 3:
raise ValueError(
f"{self.__class__} can only be run AFTER scheduler has been run "
@@ -166,7 +182,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
return self.get_prev_sample(sample, t_orig, t_orig_prev, residual)
return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, residual)}
def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf

View File

@@ -165,7 +165,7 @@ class ModelTesterMixin:
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["sample", "step_value"]
expected_arg_names = ["sample", "timestep"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_model_from_config(self):
@@ -248,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "step_value": time_step}
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
@@ -323,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
return {"sample": noise, "step_value": time_step, "low_res": low_res}
return {"sample": noise, "timestep": time_step, "low_res": low_res}
@property
def input_shape(self):
@@ -414,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
return {"sample": noise, "step_value": time_step, "transformer_out": emb}
return {"sample": noise, "timestep": time_step, "transformer_out": emb}
@property
def input_shape(self):
@@ -506,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "step_value": time_step}
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
@@ -601,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device)
return {"sample": noise, "step_value": time_step}
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
@@ -899,8 +899,8 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm = DDPMPipeline.from_pretrained(model_path)
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
ddpm.noise_scheduler.num_timesteps = 10
ddpm_from_hub.noise_scheduler.num_timesteps = 10
ddpm.scheduler.num_timesteps = 10
ddpm_from_hub.scheduler.num_timesteps = 10
generator = torch.manual_seed(0)
@@ -915,10 +915,10 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "fusing/ddpm-cifar10"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = DDPMScheduler.from_config(model_id)
noise_scheduler = noise_scheduler.set_format("pt")
scheduler = DDPMScheduler.from_config(model_id)
scheduler = scheduler.set_format("pt")
ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler)
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)
@@ -936,13 +936,12 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "fusing/ddpm-lsun-bedroom-ema"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = DDIMScheduler.from_config(model_id)
noise_scheduler = noise_scheduler.set_format("pt")
scheduler = DDIMScheduler.from_config(model_id)
ddpm = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)
image = ddpm(generator=generator)["sample"]
image_slice = image[0, -1, -3:, -3:].cpu()
@@ -957,12 +956,12 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "fusing/ddpm-cifar10"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = DDIMScheduler(tensor_format="pt")
scheduler = DDIMScheduler(tensor_format="pt")
ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0)
image = ddim(generator=generator, eta=0.0)["sample"]
image_slice = image[0, -1, -3:, -3:].cpu()
@@ -977,9 +976,9 @@ class PipelineTesterMixin(unittest.TestCase):
model_id = "fusing/ddpm-cifar10"
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
noise_scheduler = PNDMScheduler(tensor_format="pt")
scheduler = PNDMScheduler(tensor_format="pt")
pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler)
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = pndm(generator=generator)
@@ -1074,7 +1073,7 @@ class PipelineTesterMixin(unittest.TestCase):
ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True)
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5)
image = ldm(generator=generator, num_inference_steps=5)["sample"]
image_slice = image[0, -1, -3:, -3:].cpu()

View File

@@ -68,6 +68,8 @@ class SchedulerCommonTest(unittest.TestCase):
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
sample = self.dummy_sample
@@ -80,8 +82,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
new_scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -89,6 +97,8 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample
residual = 0.1 * sample
@@ -101,14 +111,24 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
new_scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
torch.manual_seed(0)
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
torch.manual_seed(0)
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_from_pretrained_save_pretrained(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample
residual = 0.1 * sample
@@ -120,14 +140,22 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step(residual, sample, 1, **kwargs)
new_output = new_scheduler.step(residual, sample, 1, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
new_scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
@@ -135,8 +163,13 @@ class SchedulerCommonTest(unittest.TestCase):
sample = self.dummy_sample
residual = 0.1 * sample
output_0 = scheduler.step(residual, sample, 0, **kwargs)
output_1 = scheduler.step(residual, sample, 1, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
@@ -144,6 +177,8 @@ class SchedulerCommonTest(unittest.TestCase):
def test_pytorch_equal_numpy(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample
residual = 0.1 * sample
@@ -156,8 +191,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
output = scheduler.step(residual, sample, 1, **kwargs)
output_pt = scheduler_pt.step(residual_pt, sample_pt, 1, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
scheduler_pt.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
@@ -226,7 +267,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
residual = model(sample, t)
# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, sample, t)
pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"]
if t > 0:
noise = self.dummy_sample_deter
@@ -243,7 +284,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
class DDIMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDIMScheduler,)
forward_default_kwargs = (("num_inference_steps", 50), ("eta", 0.0))
forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))
def get_scheduler_config(self, **kwargs):
config = {
@@ -258,7 +299,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
return config
def test_timesteps(self):
for timesteps in [1, 5, 100, 1000]:
for timesteps in [100, 500, 1000]:
self.check_over_configs(timesteps=timesteps)
def test_betas(self):
@@ -279,7 +320,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
self.check_over_forward(num_inference_steps=num_inference_steps)
def test_eta(self):
for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
@@ -290,43 +331,34 @@ class DDIMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
assert np.sum(np.abs(scheduler.get_variance(0, 50) - 0.0)) < 1e-5
assert np.sum(np.abs(scheduler.get_variance(21, 50) - 0.14771)) < 1e-5
assert np.sum(np.abs(scheduler.get_variance(49, 50) - 0.32460)) < 1e-5
assert np.sum(np.abs(scheduler.get_variance(0, 1000) - 0.0)) < 1e-5
assert np.sum(np.abs(scheduler.get_variance(487, 1000) - 0.00979)) < 1e-5
assert np.sum(np.abs(scheduler.get_variance(999, 1000) - 0.02)) < 1e-5
assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
assert np.sum(np.abs(scheduler._get_variance(420, 400) - 0.14771)) < 1e-5
assert np.sum(np.abs(scheduler._get_variance(980, 960) - 0.32460)) < 1e-5
assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
assert np.sum(np.abs(scheduler._get_variance(487, 486) - 0.00979)) < 1e-5
assert np.sum(np.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
num_inference_steps, eta = 10, 0.1
num_trained_timesteps = len(scheduler)
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
num_inference_steps, eta = 10, 0.0
model = self.dummy_model()
sample = self.dummy_sample_deter
for t in reversed(range(num_inference_steps)):
residual = model(sample, inference_step_times[t])
scheduler.set_timesteps(num_inference_steps)
for t in scheduler.timesteps:
residual = model(sample, t)
pred_prev_sample = scheduler.step(residual, sample, t, num_inference_steps, eta)
variance = 0
if eta > 0:
noise = self.dummy_sample_deter
variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise
sample = pred_prev_sample + variance
sample = scheduler.step(residual, t, sample, eta)["prev_sample"]
result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(sample))
assert abs(result_sum.item() - 270.6214) < 1e-2
assert abs(result_mean.item() - 0.3524) < 1e-3
assert abs(result_sum.item() - 172.0067) < 1e-2
assert abs(result_mean.item() - 0.223967) < 1e-3
class PNDMSchedulerTest(SchedulerCommonTest):
@@ -365,8 +397,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode()
output = scheduler.step(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -392,8 +424,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode()
output = scheduler.step(residual, sample, time_step, **kwargs)
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -445,7 +477,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler.set_plms_mode()
scheduler.step(self.dummy_sample, self.dummy_sample, 1, 50)
scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
@@ -461,14 +493,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
t_orig = prk_time_steps[t]
residual = model(sample, t_orig)
sample = scheduler.step_prk(residual, sample, t, num_inference_steps)
sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
timesteps = scheduler.get_time_steps(num_inference_steps)
for t in range(len(timesteps)):
t_orig = timesteps[t]
residual = model(sample, t_orig)
sample = scheduler.step_plms(residual, sample, t, num_inference_steps)
sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(sample))