mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
committed by
GitHub
parent
97e1e3ba76
commit
f448360bd0
@@ -438,8 +438,8 @@ class GlideTextToImageUNetModel(GlideUNetModel):
|
||||
|
||||
self.transformer_proj = nn.Linear(transformer_dim, self.model_channels * 4)
|
||||
|
||||
def forward(self, sample, step_value, transformer_out=None):
|
||||
timesteps = step_value
|
||||
def forward(self, sample, timestep, transformer_out=None):
|
||||
timesteps = timestep
|
||||
x = sample
|
||||
hs = []
|
||||
emb = self.time_embed(
|
||||
@@ -530,8 +530,8 @@ class GlideSuperResUNetModel(GlideUNetModel):
|
||||
resblock_updown=resblock_updown,
|
||||
)
|
||||
|
||||
def forward(self, sample, step_value, low_res=None):
|
||||
timesteps = step_value
|
||||
def forward(self, sample, timestep, low_res=None):
|
||||
timesteps = timestep
|
||||
x = sample
|
||||
_, _, new_height, new_width = x.shape
|
||||
upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
|
||||
|
||||
@@ -323,8 +323,8 @@ class NCSNpp(ModelMixin, ConfigMixin):
|
||||
|
||||
self.all_modules = nn.ModuleList(modules)
|
||||
|
||||
def forward(self, sample, step_value, sigmas=None):
|
||||
timesteps = step_value
|
||||
def forward(self, sample, timestep, sigmas=None):
|
||||
timesteps = timestep
|
||||
x = sample
|
||||
# timestep/noise_level embedding; only for continuous training
|
||||
modules = self.all_modules
|
||||
|
||||
@@ -254,7 +254,7 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
# ======================================
|
||||
|
||||
def forward(
|
||||
self, sample: torch.FloatTensor, step_value: Union[torch.Tensor, float, int]
|
||||
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
|
||||
) -> Dict[str, torch.FloatTensor]:
|
||||
# TODO(PVP) - to delete later at release
|
||||
# IMPORTANT: NOT RELEVANT WHEN REVIEWING API
|
||||
@@ -263,10 +263,12 @@ class UNetUnconditionalModel(ModelMixin, ConfigMixin):
|
||||
self.set_weights()
|
||||
# ======================================
|
||||
|
||||
# 1. time step embeddings
|
||||
timesteps = step_value
|
||||
# 1. time step embeddings -> make correct tensor
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
|
||||
@@ -22,19 +22,16 @@ from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
class DDIMPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, noise_scheduler):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
noise_scheduler = noise_scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
|
||||
# eta corresponds to η in paper and should be between [0, 1]
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
num_trained_timesteps = self.noise_scheduler.config.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||
|
||||
self.unet.to(torch_device)
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
@@ -44,34 +41,19 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
|
||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Ideally, read DDIM paper in-detail understanding
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# Notation (<variable name> -> <name in paper>
|
||||
# - pred_noise_t -> e_theta(x_t, t)
|
||||
# - pred_original_image -> f_theta(x_t, t) or x_0
|
||||
# - std_dev_t -> sigma_t
|
||||
# - eta -> η
|
||||
# - pred_image_direction -> "direction pointingc to x_t"
|
||||
# - pred_prev_image -> "x_t-1"
|
||||
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
|
||||
for t in tqdm.tqdm(self.scheduler.timesteps):
|
||||
# 1. predict noise residual
|
||||
with torch.no_grad():
|
||||
residual = self.unet(image, inference_step_times[t])
|
||||
residual = self.unet(image, t)
|
||||
|
||||
if isinstance(residual, dict):
|
||||
residual = residual["sample"]
|
||||
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
|
||||
# 2. predict previous mean of image x_t-1 and add variance depending on eta
|
||||
# do x_t -> x_t-1
|
||||
image = self.scheduler.step(residual, t, image, eta)["prev_sample"]
|
||||
|
||||
# 3. optionally sample variance
|
||||
variance = 0
|
||||
if eta > 0:
|
||||
noise = torch.randn(image.shape, generator=generator).to(image.device)
|
||||
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
|
||||
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
|
||||
return image
|
||||
return {"sample": image}
|
||||
|
||||
@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
class DDPMPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, noise_scheduler):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
noise_scheduler = noise_scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
def __call__(self, batch_size=1, generator=None, torch_device=None):
|
||||
if torch_device is None:
|
||||
@@ -40,7 +40,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
|
||||
num_prediction_steps = len(self.noise_scheduler)
|
||||
num_prediction_steps = len(self.scheduler)
|
||||
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
|
||||
# 1. predict noise residual
|
||||
with torch.no_grad():
|
||||
@@ -50,13 +50,13 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
residual = residual["sample"]
|
||||
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = self.noise_scheduler.step(residual, image, t)
|
||||
pred_prev_image = self.scheduler.step(residual, t, image)["prev_sample"]
|
||||
|
||||
# 3. optionally sample variance
|
||||
variance = 0
|
||||
if t > 0:
|
||||
noise = torch.randn(image.shape, generator=generator).to(image.device)
|
||||
variance = self.noise_scheduler.get_variance(t).sqrt() * noise
|
||||
variance = self.scheduler.get_variance(t).sqrt() * noise
|
||||
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
|
||||
@@ -713,20 +713,20 @@ class GlidePipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
text_unet: GlideTextToImageUNetModel,
|
||||
text_noise_scheduler: DDPMScheduler,
|
||||
text_scheduler: DDPMScheduler,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: GPT2Tokenizer,
|
||||
upscale_unet: GlideSuperResUNetModel,
|
||||
upscale_noise_scheduler: DDIMScheduler,
|
||||
upscale_scheduler: DDIMScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
text_unet=text_unet,
|
||||
text_noise_scheduler=text_noise_scheduler,
|
||||
text_scheduler=text_scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
upscale_unet=upscale_unet,
|
||||
upscale_noise_scheduler=upscale_noise_scheduler,
|
||||
upscale_scheduler=upscale_scheduler,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -777,20 +777,20 @@ class GlidePipeline(DiffusionPipeline):
|
||||
transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
|
||||
|
||||
# 3. Run the text2image generation step
|
||||
num_prediction_steps = len(self.text_noise_scheduler)
|
||||
num_prediction_steps = len(self.text_scheduler)
|
||||
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
|
||||
with torch.no_grad():
|
||||
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
|
||||
model_output = text_model_fn(image, time_input, transformer_out)
|
||||
noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
|
||||
|
||||
min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
|
||||
max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
|
||||
min_log = self.text_scheduler.get_variance(t, "fixed_small_log")
|
||||
max_log = self.text_scheduler.get_variance(t, "fixed_large_log")
|
||||
# The model_var_values is [-1, 1] for [min_var, max_var].
|
||||
frac = (model_var_values + 1) / 2
|
||||
model_log_variance = frac * max_log + (1 - frac) * min_log
|
||||
|
||||
pred_prev_image = self.text_noise_scheduler.step(noise_residual, image, t)
|
||||
pred_prev_image = self.text_scheduler.step(noise_residual, image, t)
|
||||
noise = torch.randn(image.shape, generator=generator).to(torch_device)
|
||||
variance = torch.exp(0.5 * model_log_variance) * noise
|
||||
|
||||
@@ -814,7 +814,7 @@ class GlidePipeline(DiffusionPipeline):
|
||||
).to(torch_device)
|
||||
image = image * upsample_temp
|
||||
|
||||
num_trained_timesteps = self.upscale_noise_scheduler.timesteps
|
||||
num_trained_timesteps = self.upscale_scheduler.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
|
||||
|
||||
for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
|
||||
@@ -825,7 +825,7 @@ class GlidePipeline(DiffusionPipeline):
|
||||
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
|
||||
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = self.upscale_noise_scheduler.step(
|
||||
pred_prev_image = self.upscale_scheduler.step(
|
||||
noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
|
||||
)
|
||||
|
||||
@@ -833,9 +833,7 @@ class GlidePipeline(DiffusionPipeline):
|
||||
variance = 0
|
||||
if eta > 0:
|
||||
noise = torch.randn(image.shape, generator=generator).to(torch_device)
|
||||
variance = (
|
||||
self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
|
||||
)
|
||||
variance = self.upscale_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
|
||||
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
|
||||
@@ -545,10 +545,10 @@ class LDMBertModel(LDMBertPreTrainedModel):
|
||||
|
||||
|
||||
class LatentDiffusionPipeline(DiffusionPipeline):
|
||||
def __init__(self, vqvae, bert, tokenizer, unet, noise_scheduler):
|
||||
def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
|
||||
super().__init__()
|
||||
noise_scheduler = noise_scheduler.set_format("pt")
|
||||
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, noise_scheduler=noise_scheduler)
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(vqvae=vqvae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
@@ -581,7 +581,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
|
||||
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
|
||||
text_embedding = self.bert(text_input.input_ids)
|
||||
|
||||
num_trained_timesteps = self.noise_scheduler.config.timesteps
|
||||
num_trained_timesteps = self.scheduler.config.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||
|
||||
image = torch.randn(
|
||||
@@ -622,13 +622,13 @@ class LatentDiffusionPipeline(DiffusionPipeline):
|
||||
pred_noise_t = pred_noise_t_uncond + guidance_scale * (pred_noise_t - pred_noise_t_uncond)
|
||||
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
|
||||
pred_prev_image = self.scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
|
||||
|
||||
# 3. optionally sample variance
|
||||
variance = 0
|
||||
if eta > 0:
|
||||
noise = torch.randn(image.shape, generator=generator).to(image.device)
|
||||
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
|
||||
variance = self.scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
|
||||
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
|
||||
@@ -6,10 +6,10 @@ from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
class LatentDiffusionUncondPipeline(DiffusionPipeline):
|
||||
def __init__(self, vqvae, unet, noise_scheduler):
|
||||
def __init__(self, vqvae, unet, scheduler):
|
||||
super().__init__()
|
||||
noise_scheduler = noise_scheduler.set_format("pt")
|
||||
self.register_modules(vqvae=vqvae, unet=unet, noise_scheduler=noise_scheduler)
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(vqvae=vqvae, unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
@@ -28,44 +28,23 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
|
||||
self.unet.to(torch_device)
|
||||
self.vqvae.to(torch_device)
|
||||
|
||||
num_trained_timesteps = self.noise_scheduler.config.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
|
||||
generator=generator,
|
||||
).to(torch_device)
|
||||
|
||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Ideally, read DDIM paper in-detail understanding
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# Notation (<variable name> -> <name in paper>
|
||||
# - pred_noise_t -> e_theta(x_t, t)
|
||||
# - pred_original_image -> f_theta(x_t, t) or x_0
|
||||
# - std_dev_t -> sigma_t
|
||||
# - eta -> η
|
||||
# - pred_image_direction -> "direction pointingc to x_t"
|
||||
# - pred_prev_image -> "x_t-1"
|
||||
for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
|
||||
# 1. predict noise residual
|
||||
timesteps = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
|
||||
pred_noise_t = self.unet(image, timesteps)
|
||||
for t in tqdm.tqdm(self.scheduler.timesteps):
|
||||
residual = self.unet(image, t)
|
||||
|
||||
if isinstance(pred_noise_t, dict):
|
||||
pred_noise_t = pred_noise_t["sample"]
|
||||
if isinstance(residual, dict):
|
||||
residual = residual["sample"]
|
||||
|
||||
# 2. predict previous mean of image x_t-1
|
||||
pred_prev_image = self.noise_scheduler.step(pred_noise_t, image, t, num_inference_steps, eta)
|
||||
|
||||
# 3. optionally sample variance
|
||||
variance = 0
|
||||
if eta > 0:
|
||||
noise = torch.randn(image.shape, generator=generator).to(image.device)
|
||||
variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
|
||||
|
||||
# 4. set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
# 2. predict previous mean of image x_t-1 and add variance depending on eta
|
||||
# do x_t -> x_t-1
|
||||
image = self.scheduler.step(residual, t, image, eta)["prev_sample"]
|
||||
|
||||
# decode image with vae
|
||||
image = self.vqvae.decode(image)
|
||||
return image
|
||||
return {"sample": image}
|
||||
|
||||
@@ -22,10 +22,10 @@ from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
class PNDMPipeline(DiffusionPipeline):
|
||||
def __init__(self, unet, noise_scheduler):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
noise_scheduler = noise_scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
|
||||
# For more information on the sampling method you can take a look at Algorithm 2 of
|
||||
@@ -42,7 +42,7 @@ class PNDMPipeline(DiffusionPipeline):
|
||||
)
|
||||
image = image.to(torch_device)
|
||||
|
||||
prk_time_steps = self.noise_scheduler.get_prk_time_steps(num_inference_steps)
|
||||
prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
|
||||
for t in tqdm.tqdm(range(len(prk_time_steps))):
|
||||
t_orig = prk_time_steps[t]
|
||||
residual = self.unet(image, t_orig)
|
||||
@@ -50,9 +50,9 @@ class PNDMPipeline(DiffusionPipeline):
|
||||
if isinstance(residual, dict):
|
||||
residual = residual["sample"]
|
||||
|
||||
image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
|
||||
image = self.scheduler.step_prk(residual, t, image, num_inference_steps)["prev_sample"]
|
||||
|
||||
timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
|
||||
timesteps = self.scheduler.get_time_steps(num_inference_steps)
|
||||
for t in tqdm.tqdm(range(len(timesteps))):
|
||||
t_orig = timesteps[t]
|
||||
residual = self.unet(image, t_orig)
|
||||
@@ -60,6 +60,6 @@ class PNDMPipeline(DiffusionPipeline):
|
||||
if isinstance(residual, dict):
|
||||
residual = residual["sample"]
|
||||
|
||||
image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)
|
||||
image = self.scheduler.step_plms(residual, t, image, num_inference_steps)["prev_sample"]
|
||||
|
||||
return image
|
||||
|
||||
@@ -16,8 +16,10 @@
|
||||
# and https://github.com/hojonathanho/diffusion
|
||||
|
||||
import math
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
@@ -84,14 +86,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
|
||||
self.one = np.array(1.0)
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = np.arange(0, self.config.timesteps)[::-1].copy()
|
||||
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def get_variance(self, t, num_inference_steps):
|
||||
orig_t = self.config.timesteps // num_inference_steps * t
|
||||
orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
|
||||
|
||||
alpha_prod_t = self.alphas_cumprod[orig_t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
|
||||
def _get_variance(self, timestep, prev_timestep):
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
@@ -99,7 +103,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def step(self, residual, sample, t, num_inference_steps, eta, use_clipped_residual=False):
|
||||
def set_timesteps(self, num_inference_steps):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = np.arange(0, self.config.timesteps, self.config.timesteps // self.num_inference_steps)[
|
||||
::-1
|
||||
].copy()
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
|
||||
def step(
|
||||
self,
|
||||
residual: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
eta,
|
||||
use_clipped_residual=False,
|
||||
generator=None,
|
||||
):
|
||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Ideally, read DDIM paper in-detail understanding
|
||||
|
||||
@@ -111,13 +130,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# - pred_sample_direction -> "direction pointingc to x_t"
|
||||
# - pred_prev_sample -> "x_t-1"
|
||||
|
||||
# 1. get actual t and t-1
|
||||
orig_t = self.config.timesteps // num_inference_steps * t
|
||||
orig_prev_t = self.config.timesteps // num_inference_steps * (t - 1) if t > 0 else -1
|
||||
# 1. get previous step value (=t-1)
|
||||
prev_timestep = timestep - self.config.timesteps // self.num_inference_steps
|
||||
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod[orig_t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[orig_prev_t] if orig_prev_t >= 0 else self.one
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
@@ -130,7 +148,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
variance = self.get_variance(t, num_inference_steps)
|
||||
variance = self._get_variance(timestep, prev_timestep)
|
||||
std_dev_t = eta * variance ** (0.5)
|
||||
|
||||
if use_clipped_residual:
|
||||
@@ -141,9 +159,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
|
||||
|
||||
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
|
||||
return pred_prev_sample
|
||||
if eta > 0:
|
||||
device = residual.device if torch.is_tensor(residual) else "cpu"
|
||||
noise = torch.randn(residual.shape, generator=generator).to(device)
|
||||
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * noise
|
||||
|
||||
if not torch.is_tensor(residual):
|
||||
variance = variance.numpy()
|
||||
|
||||
prev_sample = prev_sample + variance
|
||||
|
||||
return {"prev_sample": prev_sample}
|
||||
|
||||
def add_noise(self, original_samples, noise, timesteps):
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
|
||||
@@ -15,8 +15,10 @@
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
||||
|
||||
import math
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
@@ -112,7 +114,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def step(self, residual, sample, t, predict_epsilon=True):
|
||||
def step(
|
||||
self,
|
||||
residual: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
predict_epsilon=True,
|
||||
):
|
||||
t = timestep
|
||||
# 1. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
|
||||
@@ -139,7 +148,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
|
||||
|
||||
return pred_prev_sample
|
||||
return {"prev_sample": pred_prev_sample}
|
||||
|
||||
def add_noise(self, original_samples, noise, timesteps):
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
|
||||
@@ -15,8 +15,10 @@
|
||||
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
||||
|
||||
import math
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
@@ -126,7 +128,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
raise ValueError(f"mode {self.mode} does not exist.")
|
||||
|
||||
def step_prk(self, residual, sample, t, num_inference_steps):
|
||||
def step_prk(
|
||||
self,
|
||||
residual: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
num_inference_steps,
|
||||
):
|
||||
t = timestep
|
||||
prk_time_steps = self.get_prk_time_steps(num_inference_steps)
|
||||
|
||||
t_orig = prk_time_steps[t // 4 * 4]
|
||||
@@ -147,9 +156,16 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# cur_sample should not be `None`
|
||||
cur_sample = self.cur_sample if self.cur_sample is not None else sample
|
||||
|
||||
return self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)
|
||||
return {"prev_sample": self.get_prev_sample(cur_sample, t_orig, t_orig_prev, residual)}
|
||||
|
||||
def step_plms(self, residual, sample, t, num_inference_steps):
|
||||
def step_plms(
|
||||
self,
|
||||
residual: Union[torch.FloatTensor, np.ndarray],
|
||||
timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray],
|
||||
num_inference_steps,
|
||||
):
|
||||
t = timestep
|
||||
if len(self.ets) < 3:
|
||||
raise ValueError(
|
||||
f"{self.__class__} can only be run AFTER scheduler has been run "
|
||||
@@ -166,7 +182,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
|
||||
|
||||
return self.get_prev_sample(sample, t_orig, t_orig_prev, residual)
|
||||
return {"prev_sample": self.get_prev_sample(sample, t_orig, t_orig_prev, residual)}
|
||||
|
||||
def get_prev_sample(self, sample, t_orig, t_orig_prev, residual):
|
||||
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
|
||||
|
||||
@@ -165,7 +165,7 @@ class ModelTesterMixin:
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
expected_arg_names = ["sample", "step_value"]
|
||||
expected_arg_names = ["sample", "timestep"]
|
||||
self.assertListEqual(arg_names[:2], expected_arg_names)
|
||||
|
||||
def test_model_from_config(self):
|
||||
@@ -248,7 +248,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
|
||||
return {"sample": noise, "step_value": time_step}
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
@@ -323,7 +323,7 @@ class GlideSuperResUNetTests(ModelTesterMixin, unittest.TestCase):
|
||||
low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
|
||||
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
|
||||
|
||||
return {"sample": noise, "step_value": time_step, "low_res": low_res}
|
||||
return {"sample": noise, "timestep": time_step, "low_res": low_res}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
@@ -414,7 +414,7 @@ class GlideTextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
|
||||
time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
|
||||
|
||||
return {"sample": noise, "step_value": time_step, "transformer_out": emb}
|
||||
return {"sample": noise, "timestep": time_step, "transformer_out": emb}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
@@ -506,7 +506,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
|
||||
return {"sample": noise, "step_value": time_step}
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
@@ -601,7 +601,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor(batch_size * [10]).to(torch_device)
|
||||
|
||||
return {"sample": noise, "step_value": time_step}
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
@@ -899,8 +899,8 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
ddpm = DDPMPipeline.from_pretrained(model_path)
|
||||
ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path)
|
||||
|
||||
ddpm.noise_scheduler.num_timesteps = 10
|
||||
ddpm_from_hub.noise_scheduler.num_timesteps = 10
|
||||
ddpm.scheduler.num_timesteps = 10
|
||||
ddpm_from_hub.scheduler.num_timesteps = 10
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
@@ -915,10 +915,10 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
model_id = "fusing/ddpm-cifar10"
|
||||
|
||||
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
|
||||
noise_scheduler = DDPMScheduler.from_config(model_id)
|
||||
noise_scheduler = noise_scheduler.set_format("pt")
|
||||
scheduler = DDPMScheduler.from_config(model_id)
|
||||
scheduler = scheduler.set_format("pt")
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, noise_scheduler=noise_scheduler)
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator)
|
||||
@@ -936,13 +936,12 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
model_id = "fusing/ddpm-lsun-bedroom-ema"
|
||||
|
||||
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
|
||||
noise_scheduler = DDIMScheduler.from_config(model_id)
|
||||
noise_scheduler = noise_scheduler.set_format("pt")
|
||||
scheduler = DDIMScheduler.from_config(model_id)
|
||||
|
||||
ddpm = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
|
||||
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator)
|
||||
image = ddpm(generator=generator)["sample"]
|
||||
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
|
||||
@@ -957,12 +956,12 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
model_id = "fusing/ddpm-cifar10"
|
||||
|
||||
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
|
||||
noise_scheduler = DDIMScheduler(tensor_format="pt")
|
||||
scheduler = DDIMScheduler(tensor_format="pt")
|
||||
|
||||
ddim = DDIMPipeline(unet=unet, noise_scheduler=noise_scheduler)
|
||||
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddim(generator=generator, eta=0.0)
|
||||
image = ddim(generator=generator, eta=0.0)["sample"]
|
||||
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
|
||||
@@ -977,9 +976,9 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
model_id = "fusing/ddpm-cifar10"
|
||||
|
||||
unet = UNetUnconditionalModel.from_pretrained(model_id, ddpm=True)
|
||||
noise_scheduler = PNDMScheduler(tensor_format="pt")
|
||||
scheduler = PNDMScheduler(tensor_format="pt")
|
||||
|
||||
pndm = PNDMPipeline(unet=unet, noise_scheduler=noise_scheduler)
|
||||
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
|
||||
generator = torch.manual_seed(0)
|
||||
image = pndm(generator=generator)
|
||||
|
||||
@@ -1074,7 +1073,7 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||
ldm = LatentDiffusionUncondPipeline.from_pretrained("fusing/latent-diffusion-celeba-256", ldm=True)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ldm(generator=generator, num_inference_steps=5)
|
||||
image = ldm(generator=generator, num_inference_steps=5)["sample"]
|
||||
|
||||
image_slice = image[0, -1, -3:, -3:].cpu()
|
||||
|
||||
|
||||
@@ -68,6 +68,8 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
def check_over_configs(self, time_step=0, **config):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
sample = self.dummy_sample
|
||||
@@ -80,8 +82,14 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
output = scheduler.step(residual, sample, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
@@ -89,6 +97,8 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
kwargs.update(forward_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
@@ -101,14 +111,24 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
output = scheduler.step(residual, sample, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
torch.manual_seed(0)
|
||||
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
torch.manual_seed(0)
|
||||
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
@@ -120,14 +140,22 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
|
||||
output = scheduler.step(residual, sample, 1, **kwargs)
|
||||
new_output = new_scheduler.step(residual, sample, 1, **kwargs)
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
new_scheduler.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
def test_step_shape(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
@@ -135,8 +163,13 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
output_0 = scheduler.step(residual, sample, 0, **kwargs)
|
||||
output_1 = scheduler.step(residual, sample, 1, **kwargs)
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"]
|
||||
output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
|
||||
|
||||
self.assertEqual(output_0.shape, sample.shape)
|
||||
self.assertEqual(output_0.shape, output_1.shape)
|
||||
@@ -144,6 +177,8 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
def test_pytorch_equal_numpy(self):
|
||||
kwargs = dict(self.forward_default_kwargs)
|
||||
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
@@ -156,8 +191,14 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
|
||||
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
|
||||
|
||||
output = scheduler.step(residual, sample, 1, **kwargs)
|
||||
output_pt = scheduler_pt.step(residual_pt, sample_pt, 1, **kwargs)
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
scheduler_pt.set_timesteps(num_inference_steps)
|
||||
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
|
||||
kwargs["num_inference_steps"] = num_inference_steps
|
||||
|
||||
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
|
||||
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
|
||||
|
||||
@@ -226,7 +267,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
residual = model(sample, t)
|
||||
|
||||
# 2. predict previous mean of sample x_t-1
|
||||
pred_prev_sample = scheduler.step(residual, sample, t)
|
||||
pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"]
|
||||
|
||||
if t > 0:
|
||||
noise = self.dummy_sample_deter
|
||||
@@ -243,7 +284,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
class DDIMSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (DDIMScheduler,)
|
||||
forward_default_kwargs = (("num_inference_steps", 50), ("eta", 0.0))
|
||||
forward_default_kwargs = (("eta", 0.0), ("num_inference_steps", 50))
|
||||
|
||||
def get_scheduler_config(self, **kwargs):
|
||||
config = {
|
||||
@@ -258,7 +299,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
||||
return config
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [1, 5, 100, 1000]:
|
||||
for timesteps in [100, 500, 1000]:
|
||||
self.check_over_configs(timesteps=timesteps)
|
||||
|
||||
def test_betas(self):
|
||||
@@ -279,7 +320,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
def test_inference_steps(self):
|
||||
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
|
||||
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
|
||||
self.check_over_forward(num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_eta(self):
|
||||
for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
|
||||
@@ -290,43 +331,34 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
assert np.sum(np.abs(scheduler.get_variance(0, 50) - 0.0)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler.get_variance(21, 50) - 0.14771)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler.get_variance(49, 50) - 0.32460)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler.get_variance(0, 1000) - 0.0)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler.get_variance(487, 1000) - 0.00979)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler.get_variance(999, 1000) - 0.02)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler._get_variance(420, 400) - 0.14771)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler._get_variance(980, 960) - 0.32460)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler._get_variance(0, 0) - 0.0)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler._get_variance(487, 486) - 0.00979)) < 1e-5
|
||||
assert np.sum(np.abs(scheduler._get_variance(999, 998) - 0.02)) < 1e-5
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
num_inference_steps, eta = 10, 0.1
|
||||
num_trained_timesteps = len(scheduler)
|
||||
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
|
||||
num_inference_steps, eta = 10, 0.0
|
||||
|
||||
model = self.dummy_model()
|
||||
sample = self.dummy_sample_deter
|
||||
|
||||
for t in reversed(range(num_inference_steps)):
|
||||
residual = model(sample, inference_step_times[t])
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
for t in scheduler.timesteps:
|
||||
residual = model(sample, t)
|
||||
|
||||
pred_prev_sample = scheduler.step(residual, sample, t, num_inference_steps, eta)
|
||||
|
||||
variance = 0
|
||||
if eta > 0:
|
||||
noise = self.dummy_sample_deter
|
||||
variance = scheduler.get_variance(t, num_inference_steps) ** (0.5) * eta * noise
|
||||
|
||||
sample = pred_prev_sample + variance
|
||||
sample = scheduler.step(residual, t, sample, eta)["prev_sample"]
|
||||
|
||||
result_sum = np.sum(np.abs(sample))
|
||||
result_mean = np.mean(np.abs(sample))
|
||||
|
||||
assert abs(result_sum.item() - 270.6214) < 1e-2
|
||||
assert abs(result_mean.item() - 0.3524) < 1e-3
|
||||
assert abs(result_sum.item() - 172.0067) < 1e-2
|
||||
assert abs(result_mean.item() - 0.223967) < 1e-3
|
||||
|
||||
|
||||
class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
@@ -365,8 +397,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
new_scheduler.set_plms_mode()
|
||||
|
||||
output = scheduler.step(residual, sample, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
|
||||
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
@@ -392,8 +424,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
new_scheduler.ets = dummy_past_residuals[:]
|
||||
new_scheduler.set_plms_mode()
|
||||
|
||||
output = scheduler.step(residual, sample, time_step, **kwargs)
|
||||
new_output = new_scheduler.step(residual, sample, time_step, **kwargs)
|
||||
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
|
||||
|
||||
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
|
||||
|
||||
@@ -445,7 +477,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
|
||||
scheduler.set_plms_mode()
|
||||
|
||||
scheduler.step(self.dummy_sample, self.dummy_sample, 1, 50)
|
||||
scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
|
||||
|
||||
def test_full_loop_no_noise(self):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
@@ -461,14 +493,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
||||
t_orig = prk_time_steps[t]
|
||||
residual = model(sample, t_orig)
|
||||
|
||||
sample = scheduler.step_prk(residual, sample, t, num_inference_steps)
|
||||
sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
|
||||
|
||||
timesteps = scheduler.get_time_steps(num_inference_steps)
|
||||
for t in range(len(timesteps)):
|
||||
t_orig = timesteps[t]
|
||||
residual = model(sample, t_orig)
|
||||
|
||||
sample = scheduler.step_plms(residual, sample, t, num_inference_steps)
|
||||
sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
|
||||
|
||||
result_sum = np.sum(np.abs(sample))
|
||||
result_mean = np.mean(np.abs(sample))
|
||||
|
||||
Reference in New Issue
Block a user