1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
Patrick von Platen
2022-07-20 21:02:43 +00:00
13 changed files with 372 additions and 302 deletions

View File

@@ -5,18 +5,17 @@
The command to train a DDPM UNet model on the Oxford Flowers dataset:
```bash
python -m torch.distributed.launch \
--nproc_per_node 4 \
train_unconditional.py \
accelerate launch train_unconditional.py \
--dataset="huggan/flowers-102-categories" \
--resolution=64 \
--output_dir="flowers-ddpm" \
--batch_size=16 \
--output_dir="ddpm-ema-flowers-64" \
--train_batch_size=16 \
--num_epochs=100 \
--gradient_accumulation_steps=1 \
--lr=1e-4 \
--warmup_steps=500 \
--mixed_precision=no
--learning_rate=1e-4 \
--lr_warmup_steps=500 \
--mixed_precision=no \
--push_to_hub
```
A full training run takes 2 hours on 4xV100 GPUs.
@@ -29,18 +28,17 @@ A full training run takes 2 hours on 4xV100 GPUs.
The command to train a DDPM UNet model on the Pokemon dataset:
```bash
python -m torch.distributed.launch \
--nproc_per_node 4 \
train_unconditional.py \
accelerate launch train_unconditional.py \
--dataset="huggan/pokemon" \
--resolution=64 \
--output_dir="pokemon-ddpm" \
--batch_size=16 \
--output_dir="ddpm-ema-pokemon-64" \
--train_batch_size=16 \
--num_epochs=100 \
--gradient_accumulation_steps=1 \
--lr=1e-4 \
--warmup_steps=500 \
--mixed_precision=no
--learning_rate=1e-4 \
--lr_warmup_steps=500 \
--mixed_precision=no \
--push_to_hub
```
A full training run takes 2 hours on 4xV100 GPUs.

View File

@@ -4,10 +4,10 @@ import os
import torch
import torch.nn.functional as F
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate import Accelerator
from accelerate.logging import get_logger
from datasets import load_dataset
from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
@@ -27,25 +27,37 @@ logger = get_logger(__name__)
def main(args):
ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
log_with="tensorboard",
logging_dir=logging_dir,
kwargs_handlers=[ddp_unused_params],
)
model = UNetModel(
attn_resolutions=(16,),
ch=128,
ch_mult=(1, 2, 4, 8),
dropout=0.0,
model = UNetUnconditionalModel(
image_size=args.resolution,
in_channels=3,
out_channels=3,
num_res_blocks=2,
resamp_with_conv=True,
resolution=args.resolution,
block_channels=(128, 128, 256, 256, 512, 512),
down_blocks=(
"UNetResDownBlock2D",
"UNetResDownBlock2D",
"UNetResDownBlock2D",
"UNetResDownBlock2D",
"UNetResAttnDownBlock2D",
"UNetResDownBlock2D",
),
up_blocks=(
"UNetResUpBlock2D",
"UNetResAttnUpBlock2D",
"UNetResUpBlock2D",
"UNetResUpBlock2D",
"UNetResUpBlock2D",
"UNetResUpBlock2D",
),
)
noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt")
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
optimizer = torch.optim.AdamW(
model.parameters(),
lr=args.learning_rate,
@@ -92,19 +104,6 @@ def main(args):
run = os.path.split(__file__)[-1].split(".")[0]
accelerator.init_trackers(run)
# Train!
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size() if is_distributed else 1
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataloader.dataset)}")
logger.info(f" Num Epochs = {args.num_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
global_step = 0
for epoch in range(args.num_epochs):
model.train()
@@ -112,45 +111,37 @@ def main(args):
progress_bar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"]
noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
# Sample noise that we'll add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
bsz = clean_images.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
# Sample a random timestep for each image
timesteps = torch.randint(
0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
).long()
# add noise onto the clean images according to the noise magnitude at each timestep
# Add noise to the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps)
noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model):
output = model(noisy_images, timesteps)
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
else:
output = model(noisy_images, timesteps)
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
loss = loss / args.gradient_accumulation_steps
with accelerator.accumulate(model):
# Predict the noise residual
noise_pred = model(noisy_images, timesteps)["sample"]
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
ema_model.step(model)
if args.use_ema:
ema_model.step(model)
optimizer.zero_grad()
progress_bar.update(1)
progress_bar.set_postfix(
loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
)
accelerator.log(
{
"train_loss": loss.detach().item(),
"epoch": epoch,
"ema_decay": ema_model.decay,
"step": global_step,
},
step=global_step,
)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
if args.use_ema:
logs["ema_decay"] = ema_model.decay
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
global_step += 1
progress_bar.close()
@@ -159,14 +150,14 @@ def main(args):
# Generate a sample image for visual inspection
if accelerator.is_main_process:
with torch.no_grad():
pipeline = DDIMPipeline(
unet=accelerator.unwrap_model(ema_model.averaged_model),
noise_scheduler=noise_scheduler,
pipeline = DDPMPipeline(
unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
scheduler=noise_scheduler,
)
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50)
images = pipeline(generator=generator, batch_size=args.eval_batch_size)
# denormalize the images and save to tensorboard
images_processed = (images.cpu() + 1.0) * 127.5
@@ -174,11 +165,12 @@ def main(args):
accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
# save the model
if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone()
accelerator.end_training()
@@ -188,12 +180,13 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
parser.add_argument("--output_dir", type=str, default="ddpm-model")
parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64")
parser.add_argument("--overwrite_output_dir", action="store_true")
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--save_model_epochs", type=int, default=5)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--lr_scheduler", type=str, default="cosine")
@@ -202,6 +195,7 @@ if __name__ == "__main__":
parser.add_argument("--adam_beta2", type=float, default=0.999)
parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
parser.add_argument("--adam_epsilon", type=float, default=1e-3)
# EMA is on by default. `store_true` together with `default=True` can never yield False, so a
# `store_false` flag writing to the same destination is provided to switch EMA off.
parser.add_argument("--use_ema", dest="use_ema", action="store_true", default=True)
parser.add_argument("--no_ema", dest="use_ema", action="store_false")
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
parser.add_argument("--ema_power", type=float, default=3 / 4)
parser.add_argument("--ema_max_decay", type=float, default=0.9999)

View File

@@ -145,7 +145,7 @@ class Decoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

View File

@@ -19,6 +19,7 @@ import os
from typing import Optional, Union
from huggingface_hub import snapshot_download
from PIL import Image
from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, logging
@@ -120,6 +121,7 @@ class DiffusionPipeline(ConfigMixin):
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
@@ -131,6 +133,7 @@ class DiffusionPipeline(ConfigMixin):
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
)
else:
cached_folder = pretrained_model_name_or_path
@@ -187,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
# 5. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
return model
@staticmethod
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    Args:
        images: channels-last numpy array, either a single image with 3 dimensions or a batch with a
            leading batch dimension. Values are scaled by 255 and cast to `uint8`, so they are expected
            to lie in [0, 1].

    Returns:
        A list holding one PIL image per input image (a single image yields a one-element list).
    """
    if images.ndim == 3:
        # a single image: add the batch axis so both cases share one code path
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")

    if images.shape[-1] == 1:
        # `Image.fromarray` cannot interpret an (H, W, 1) array; drop the channel axis so that
        # single-channel images are decoded as 8-bit grayscale instead of raising
        return [Image.fromarray(image[..., 0]) for image in images]

    return [Image.fromarray(image) for image in images]

View File

@@ -28,7 +28,7 @@ class DDIMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -53,4 +53,9 @@ class DDIMPipeline(DiffusionPipeline):
# do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}

View File

@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None):
def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -54,4 +54,9 @@ class DDPMPipeline(DiffusionPipeline):
# 3. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}

View File

@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.utils.checkpoint
import tqdm
from tqdm.auto import tqdm
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
@@ -30,51 +30,67 @@ class LatentDiffusionPipeline(DiffusionPipeline):
eta=0.0,
guidance_scale=1.0,
num_inference_steps=50,
output_type="pil",
):
# eta corresponds to η in paper and should be between [0, 1]
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = len(prompt)
self.unet.to(torch_device)
self.vqvae.to(torch_device)
self.bert.to(torch_device)
# get unconditional embeddings for classifier free guidence
# get unconditional embeddings for classifier free guidance
if guidance_scale != 1.0:
uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
torch_device
)
uncond_embeddings = self.bert(uncond_input.input_ids)
uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))
# get text embedding
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
text_embedding = self.bert(text_input.input_ids)
# get prompt text embeddings
text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
text_embeddings = self.bert(text_input.input_ids.to(torch_device))
image = torch.randn(
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator,
).to(torch_device)
)
latents = latents.to(torch_device)
self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm.tqdm(self.scheduler.timesteps):
# 1. predict noise residual
pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
for t in tqdm(self.scheduler.timesteps):
if guidance_scale == 1.0:
# guidance_scale of 1 means no guidance
latents_input = latents
context = text_embeddings
else:
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input = torch.cat([latents] * 2)
context = torch.cat([uncond_embeddings, text_embeddings])
if isinstance(pred_noise_t, dict):
pred_noise_t = pred_noise_t["sample"]
# predict the noise residual
noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
# perform guidance
if guidance_scale != 1.0:
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, eta)["prev_sample"]
# scale and decode image with vae
image = 1 / 0.18215 * image
image = self.vqvae.decode(image)
image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vqvae.decode(latents)
return image
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
################################################################################

View File

@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
@torch.no_grad()
def __call__(
self,
batch_size=1,
generator=None,
torch_device=None,
eta=0.0,
num_inference_steps=50,
self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
):
# eta corresponds to η in paper and should be between [0, 1]
@@ -28,25 +23,26 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
self.unet.to(torch_device)
self.vqvae.to(torch_device)
image = torch.randn(
latents = torch.randn(
(batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
generator=generator,
).to(torch_device)
)
latents = latents.to(torch_device)
self.scheduler.set_timesteps(num_inference_steps)
for t in tqdm(self.scheduler.timesteps):
with torch.no_grad():
model_output = self.unet(image, t)
# predict the noise residual
noise_prediction = self.unet(latents, t)["sample"]
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
if isinstance(model_output, dict):
model_output = model_output["sample"]
# decode the image latents with the VAE
image = self.vqvae.decode(latents)
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
# decode image with vae
with torch.no_grad():
image = self.vqvae.decode(image)
return {"sample": image}

View File

@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
if torch_device is None:
@@ -43,18 +43,20 @@ class PNDMPipeline(DiffusionPipeline):
)
image = image.to(torch_device)
prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
for t in tqdm(range(len(prk_time_steps))):
t_orig = prk_time_steps[t]
model_output = self.unet(image, t_orig)["sample"]
self.scheduler.set_timesteps(num_inference_steps)
for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
model_output = self.unet(image, t)["sample"]
image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"]
image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]
timesteps = self.scheduler.get_time_steps(num_inference_steps)
for t in tqdm(range(len(timesteps))):
t_orig = timesteps[t]
model_output = self.unet(image, t_orig)["sample"]
for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
model_output = self.unet(image, t)["sample"]
image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]
image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}

View File

@@ -2,15 +2,16 @@
import torch
from diffusers import DiffusionPipeline
from tqdm.auto import tqdm
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVePipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
super().__init__()
self.register_modules(model=model, scheduler=scheduler)
def __call__(self, num_inference_steps=2000, generator=None):
@torch.no_grad()
def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
img_size = self.model.config.image_size
@@ -24,12 +25,11 @@ class ScoreSdeVePipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.set_sigmas(num_inference_steps)
for i, t in enumerate(self.scheduler.timesteps):
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
for _ in range(self.scheduler.correct_steps):
with torch.no_grad():
model_output = self.model(sample, sigma_t)
model_output = self.model(sample, sigma_t)
if isinstance(model_output, dict):
model_output = model_output["sample"]
@@ -45,4 +45,9 @@ class ScoreSdeVePipeline(DiffusionPipeline):
output = self.scheduler.step_pred(model_output, t, sample)
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
return sample_mean
sample = sample.clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
sample = self.numpy_to_pil(sample)
return {"sample": sample}

View File

@@ -15,6 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
import pdb
from typing import Union
import numpy as np
@@ -71,8 +72,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.one = np.array(1.0)
self.set_format(tensor_format=tensor_format)
# For now we only support F-PNDM, i.e. the runge-kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -82,49 +81,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.cur_model_output = 0
self.cur_sample = None
self.ets = []
self.prk_time_steps = {}
self.time_steps = {}
self.set_prk_mode()
def get_prk_time_steps(self, num_inference_steps):
if num_inference_steps in self.prk_time_steps:
return self.prk_time_steps[num_inference_steps]
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
self.prk_timesteps = None
self.plms_timesteps = None
inference_step_times = list(
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps):
self.num_inference_steps = num_inference_steps
self.timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile(
np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
)
self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
self.plms_timesteps = list(reversed(self.timesteps[:-3]))
return self.prk_time_steps[num_inference_steps]
def get_time_steps(self, num_inference_steps):
if num_inference_steps in self.time_steps:
return self.time_steps[num_inference_steps]
inference_step_times = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
return self.time_steps[num_inference_steps]
def set_prk_mode(self):
self.mode = "prk"
def set_plms_mode(self):
self.mode = "plms"
def step(self, *args, **kwargs):
if self.mode == "prk":
return self.step_prk(*args, **kwargs)
if self.mode == "plms":
return self.step_plms(*args, **kwargs)
raise ValueError(f"mode {self.mode} does not exist.")
self.set_format(tensor_format=self.tensor_format)
def step_prk(
self,
@@ -138,7 +117,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
solution to the differential equation.
"""
t = timestep
prk_time_steps = self.get_prk_time_steps(num_inference_steps)
prk_time_steps = self.prk_timesteps
t_orig = prk_time_steps[t // 4 * 4]
t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
@@ -180,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"for more information."
)
timesteps = self.get_time_steps(num_inference_steps)
timesteps = self.plms_timesteps
t_orig = timesteps[t]
t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]

View File

@@ -18,11 +18,11 @@ import inspect
import math
import tempfile
import unittest
from atexit import register
import numpy as np
import torch
import PIL
from diffusers import UNetConditionalModel # noqa: F401 TODO(Patrick) - need to write tests with it
from diffusers import (
AutoencoderKL,
@@ -704,11 +704,11 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"]
image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0)
new_image = new_ddpm(generator=generator)["sample"]
new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_from_pretrained_hub(self):
@@ -722,11 +722,32 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"]
image = ddpm(generator=generator, output_type="numpy")["sample"]
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator)["sample"]
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@slow
def test_output_format(self):
    """Check the container type the DDIM pipeline returns for each `output_type`."""
    ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
    generator = torch.manual_seed(0)

    # "numpy" yields one channels-last array for the whole batch
    numpy_output = ddim(generator=generator, output_type="numpy")["sample"]
    assert numpy_output.shape == (1, 32, 32, 3)
    assert isinstance(numpy_output, np.ndarray)

    # "pil" yields a list with one PIL image per batch element
    pil_output = ddim(generator=generator, output_type="pil")["sample"]
    assert isinstance(pil_output, list)
    assert len(pil_output) == 1
    assert isinstance(pil_output[0], PIL.Image.Image)

    # omitting `output_type` falls back to PIL
    default_output = ddim(generator=generator)["sample"]
    assert isinstance(default_output, list)
    assert isinstance(default_output[0], PIL.Image.Image)
@slow
def test_ddpm_cifar10(self):
@@ -739,15 +760,13 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"]
image = ddpm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu()
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor(
[-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ddim_lsun(self):
@@ -759,15 +778,13 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = ddpm(generator=generator)["sample"]
image = ddpm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu()
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor(
[-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ddim_cifar10(self):
@@ -779,15 +796,13 @@ class PipelineTesterMixin(unittest.TestCase):
ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0)["sample"]
image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu()
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor(
[-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_pndm_cifar10(self):
@@ -798,15 +813,13 @@ class PipelineTesterMixin(unittest.TestCase):
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
generator = torch.manual_seed(0)
image = pndm(generator=generator)["sample"]
image = pndm(generator=generator, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu()
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 32, 32)
expected_slice = torch.tensor(
[-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ldm_text2img(self):
@@ -814,13 +827,15 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, num_inference_steps=20)
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
"sample"
]
image_slice = image[0, -1, -3:, -3:].cpu()
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ldm_text2img_fast(self):
@@ -828,55 +843,44 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, num_inference_steps=1)
image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu()
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_score_sde_ve_pipeline(self):
model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024")
model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024")
scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
torch.manual_seed(0)
image = sde_ve(num_inference_steps=2)
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
if model.device.type == "cpu":
# patrick's cpu
expected_image_sum = 3384805888.0
expected_image_mean = 1076.00085
image_slice = image[0, -3:, -3:, -1]
# m1 mbp
# expected_image_sum = 3384805376.0
# expected_image_mean = 1076.000610351562
else:
expected_image_sum = 3382849024.0
expected_image_mean = 1075.3788
assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@slow
def test_ldm_uncond(self):
ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5)["sample"]
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
image_slice = image[0, -1, -3:, -3:].cpu()
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor(
[-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
)
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
assert image.shape == (1, 256, 256, 3)
expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

View File

@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase):
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
sample = self.dummy_sample
residual = 0.1 * sample
@@ -102,7 +101,6 @@ class SchedulerCommonTest(unittest.TestCase):
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest):
config.update(**kwargs)
return config
def check_over_configs_pmls(self, time_step=0, **config):
def check_over_configs(self, time_step=0, **config):
kwargs = dict(self.forward_default_kwargs)
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(kwargs["num_inference_steps"])
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
scheduler.set_plms_mode()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode()
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_from_pretrained_save_pretrained(self):
pass
def check_over_forward(self, time_step=0, **forward_kwargs):
kwargs = dict(self.forward_default_kwargs)
kwargs.update(forward_kwargs)
sample = self.dummy_sample
@@ -409,74 +414,127 @@ class PNDMSchedulerTest(SchedulerCommonTest):
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(kwargs["num_inference_steps"])
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
scheduler.set_plms_mode()
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
new_scheduler.set_plms_mode()
new_scheduler.set_timesteps(kwargs["num_inference_steps"])
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
def test_pytorch_equal_numpy(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample
residual = 0.1 * sample
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
sample_pt = torch.tensor(sample)
residual_pt = 0.1 * sample_pt
dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
# copy over dummy past residuals
scheduler.ets = dummy_past_residuals[:]
scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
# copy over dummy past residuals
scheduler_pt.ets = dummy_past_residuals_pt[:]
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
scheduler_pt.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_step_shape(self):
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
# copy over dummy past residuals
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler.ets = dummy_past_residuals[:]
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
def test_timesteps(self):
for timesteps in [100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_timesteps_pmls(self):
for timesteps in [100, 1000]:
self.check_over_configs_pmls(num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_betas_pmls(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_schedules_pmls(self):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_time_indices(self):
for t in [1, 5, 10]:
self.check_over_forward(time_step=t)
def test_time_indices_pmls(self):
for t in [1, 5, 10]:
self.check_over_forward_pmls(time_step=t)
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_inference_steps_pmls(self):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
def test_inference_pmls_no_past_residuals(self):
def test_inference_plms_no_past_residuals(self):
with self.assertRaises(ValueError):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_plms_mode()
scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
@@ -486,20 +544,15 @@ class PNDMSchedulerTest(SchedulerCommonTest):
num_inference_steps = 10
model = self.dummy_model()
sample = self.dummy_sample_deter
scheduler.set_timesteps(num_inference_steps)
prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
for t in range(len(prk_time_steps)):
t_orig = prk_time_steps[t]
residual = model(sample, t_orig)
for i, t in enumerate(scheduler.prk_timesteps):
residual = model(sample, t)
sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
timesteps = scheduler.get_time_steps(num_inference_steps)
for t in range(len(timesteps)):
t_orig = timesteps[t]
residual = model(sample, t_orig)
sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t)
sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
result_sum = np.sum(np.abs(sample))
result_mean = np.mean(np.abs(sample))
@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
kwargs = dict(self.forward_default_kwargs)
for scheduler_class in self.scheduler_classes:
scheduler_class = self.scheduler_classes[0]
sample = self.dummy_sample
residual = 0.1 * sample
@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)