diff --git a/examples/README.md b/examples/README.md
index d806e852e9..c09baa8ead 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -5,18 +5,17 @@
 The command to train a DDPM UNet model on the Oxford Flowers dataset:
 
 ```bash
-python -m torch.distributed.launch \
-  --nproc_per_node 4 \
-  train_unconditional.py \
+accelerate launch train_unconditional.py \
   --dataset="huggan/flowers-102-categories" \
   --resolution=64 \
-  --output_dir="flowers-ddpm" \
-  --batch_size=16 \
+  --output_dir="ddpm-ema-flowers-64" \
+  --train_batch_size=16 \
   --num_epochs=100 \
   --gradient_accumulation_steps=1 \
-  --lr=1e-4 \
-  --warmup_steps=500 \
-  --mixed_precision=no
+  --learning_rate=1e-4 \
+  --lr_warmup_steps=500 \
+  --mixed_precision=no \
+  --push_to_hub
 ```
 
 A full training run takes 2 hours on 4xV100 GPUs.
@@ -29,18 +28,17 @@
 The command to train a DDPM UNet model on the Pokemon dataset:
 
 ```bash
-python -m torch.distributed.launch \
-  --nproc_per_node 4 \
-  train_unconditional.py \
+accelerate launch train_unconditional.py \
   --dataset="huggan/pokemon" \
   --resolution=64 \
-  --output_dir="pokemon-ddpm" \
-  --batch_size=16 \
+  --output_dir="ddpm-ema-pokemon-64" \
+  --train_batch_size=16 \
   --num_epochs=100 \
   --gradient_accumulation_steps=1 \
-  --lr=1e-4 \
-  --warmup_steps=500 \
-  --mixed_precision=no
+  --learning_rate=1e-4 \
+  --lr_warmup_steps=500 \
+  --mixed_precision=no \
+  --push_to_hub
 ```
 
 A full training run takes 2 hours on 4xV100 GPUs.
diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py
index ebe5eb9826..787cdbb288 100644
--- a/examples/train_unconditional.py
+++ b/examples/train_unconditional.py
@@ -4,10 +4,10 @@ import os
 
 import torch
 import torch.nn.functional as F
-from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -27,25 +27,37 @@ logger = get_logger(__name__)
 
 
 def main(args):
-    ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
-        kwargs_handlers=[ddp_unused_params],
     )
 
-    model = UNetModel(
-        attn_resolutions=(16,),
-        ch=128,
-        ch_mult=(1, 2, 4, 8),
-        dropout=0.0,
+    model = UNetUnconditionalModel(
+        image_size=args.resolution,
+        in_channels=3,
+        out_channels=3,
         num_res_blocks=2,
-        resamp_with_conv=True,
-        resolution=args.resolution,
+        block_channels=(128, 128, 256, 256, 512, 512),
+        down_blocks=(
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResAttnDownBlock2D",
+            "UNetResDownBlock2D",
+        ),
+        up_blocks=(
+            "UNetResUpBlock2D",
+            "UNetResAttnUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+        ),
     )
-    noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt")
+    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -92,19 +104,6 @@ def main(args):
     run = os.path.split(__file__)[-1].split(".")[0]
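For reference, the `DDPMScheduler` configured above provides the `add_noise` operation used in the training loop below; it implements the standard DDPM forward process q(x_t | x_0). A minimal sketch of that computation — the helper name and the explicit `alphas_cumprod` argument are illustrative, not taken from this PR:

```python
# Sketch of the DDPM forward process:
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
# `alphas_cumprod[t]` is the cumulative product of (1 - beta_s) for s <= t.
import torch


def add_noise_sketch(clean_images, noise, timesteps, alphas_cumprod):
    sqrt_alpha_bar = alphas_cumprod[timesteps].sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus = (1.0 - alphas_cumprod[timesteps]).sqrt().view(-1, 1, 1, 1)
    # broadcast the per-sample scalars over the (C, H, W) image dimensions
    return sqrt_alpha_bar * clean_images + sqrt_one_minus * noise
```

Training then regresses the UNet output against `eps`, which is exactly the `F.mse_loss(noise_pred, noise)` objective in the loop below.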
     accelerator.init_trackers(run)
 
-    # Train!
-    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
-    world_size = torch.distributed.get_world_size() if is_distributed else 1
-    total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
-    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
-    logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
-    logger.info(f"  Num Epochs = {args.num_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {max_steps}")
-
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
@@ -112,45 +111,37 @@ def main(args):
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
-            noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
+            # Sample noise that we'll add to the images
+            noise = torch.randn(clean_images.shape).to(clean_images.device)
             bsz = clean_images.shape[0]
-            timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
+            # Sample a random timestep for each image
+            timesteps = torch.randint(
+                0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
+            ).long()
 
-            # add noise onto the clean images according to the noise magnitude at each timestep
+            # Add noise to the clean images according to the noise magnitude at each timestep
             # (this is the forward diffusion process)
-            noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps)
+            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
 
-            if step % args.gradient_accumulation_steps != 0:
-                with accelerator.no_sync(model):
-                    output = model(noisy_images, timesteps)
-                    # predict the noise residual
-                    loss = F.mse_loss(output, noise_samples)
-                    loss = loss / args.gradient_accumulation_steps
-                    accelerator.backward(loss)
-            else:
-                output = model(noisy_images, timesteps)
-                # predict the noise residual
-                loss = F.mse_loss(output, noise_samples)
-                loss = loss / args.gradient_accumulation_steps
+            with accelerator.accumulate(model):
+                # Predict the noise residual
+                noise_pred = model(noisy_images, timesteps)["sample"]
+                loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)
-                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
-                ema_model.step(model)
+                if args.use_ema:
+                    ema_model.step(model)
                 optimizer.zero_grad()
+
             progress_bar.update(1)
-            progress_bar.set_postfix(
-                loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
-            )
-            accelerator.log(
-                {
-                    "train_loss": loss.detach().item(),
-                    "epoch": epoch,
-                    "ema_decay": ema_model.decay,
-                    "step": global_step,
-                },
-                step=global_step,
-            )
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+            if args.use_ema:
+                logs["ema_decay"] = ema_model.decay
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
             global_step += 1
         progress_bar.close()
 
@@ -159,14 +150,14 @@ def main(args):
 
         # Generate a sample image for visual inspection
         if accelerator.is_main_process:
             with torch.no_grad():
-                pipeline = DDIMPipeline(
-                    unet=accelerator.unwrap_model(ema_model.averaged_model),
-                    noise_scheduler=noise_scheduler,
+                pipeline = DDPMPipeline(
+                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
+                    scheduler=noise_scheduler,
                 )
 
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50)
+                images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
 
                 # denormalize the images and save to tensorboard
-                images_processed = (images.cpu() + 1.0) * 127.5
+                images_processed = (images * 255).round().astype("uint8")
@@ -174,11 +165,12 @@ def main(args):
 
-            accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
+            accelerator.trackers[0].writer.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
 
-            # save the model
-            if args.push_to_hub:
-                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
-            else:
-                pipeline.save_pretrained(args.output_dir)
+            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
+                # save the model
+                if args.push_to_hub:
+                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
+                else:
+                    pipeline.save_pretrained(args.output_dir)
         accelerator.wait_for_everyone()
 
     accelerator.end_training()
@@ -188,12 +180,13 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument("--local_rank", type=int, default=-1)
     parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model")
+    parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64")
    parser.add_argument("--overwrite_output_dir", action="store_true")
     parser.add_argument("--resolution", type=int, default=64)
     parser.add_argument("--train_batch_size", type=int, default=16)
     parser.add_argument("--eval_batch_size", type=int, default=16)
     parser.add_argument("--num_epochs", type=int, default=100)
+    parser.add_argument("--save_model_epochs", type=int, default=5)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--learning_rate", type=float, default=1e-4)
     parser.add_argument("--lr_scheduler", type=str, default="cosine")
@@ -202,6 +195,7 @@ if __name__ == "__main__":
     parser.add_argument("--adam_beta2", type=float, default=0.999)
     parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
     parser.add_argument("--adam_epsilon", type=float, default=1e-3)
+    parser.add_argument("--use_ema", action="store_true", default=True)
     parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
     parser.add_argument("--ema_power", type=float, default=3 / 4)
     parser.add_argument("--ema_max_decay", type=float, default=0.9999)
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index 95b22fcd59..d16ab792f5 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+        # print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
 
         # z to block_in
         self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index b0ff25c339..0fa6852bd1 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -19,6 +19,7 @@ import os
 from typing import Optional, Union
 
 from huggingface_hub import snapshot_download
+from PIL import Image
 
 from .configuration_utils import ConfigMixin
 from .utils import DIFFUSERS_CACHE, logging
@@ -120,6 +121,7 @@ class DiffusionPipeline(ConfigMixin):
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -131,6 +133,7 @@ class DiffusionPipeline(ConfigMixin):
                 proxies=proxies,
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
+                revision=revision,
             )
         else:
             cached_folder = pretrained_model_name_or_path
@@ -187,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
         # 5. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)
         return model
+
+    @staticmethod
+    def numpy_to_pil(images):
+        """
+        Convert a numpy image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        pil_images = [Image.fromarray(image) for image in images]
+
+        return pil_images
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index 8dba08f728..5f9227c9cb 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -28,7 +28,7 @@ class DDIMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
+    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
         # eta corresponds to η in paper and should be between [0, 1]
 
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -53,4 +53,9 @@ class DDIMPipeline(DiffusionPipeline):
             # do x_t -> x_t-1
             image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
 
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
         return {"sample": image}
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index e72b05cf89..a7309224ef 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None):
+    def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -54,4 +54,9 @@ class DDPMPipeline(DiffusionPipeline):
             # 3. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image
 
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
         return {"sample": image}
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index fb85608ae6..5b3c5dc8cb 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint
 
-import tqdm
+from tqdm.auto import tqdm
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
@@ -30,51 +30,67 @@ class LatentDiffusionPipeline(DiffusionPipeline):
         eta=0.0,
         guidance_scale=1.0,
         num_inference_steps=50,
+        output_type="pil",
     ):
         # eta corresponds to η in paper and should be between [0, 1]
 
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
+        batch_size = len(prompt)
+
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)
 
-        # get unconditional embeddings for classifier free guidence
+        # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
-                torch_device
-            )
-            uncond_embeddings = self.bert(uncond_input.input_ids)
+            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))
 
-        # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-        text_embedding = self.bert(text_input.input_ids)
+        # get prompt text embeddings
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_embeddings = self.bert(text_input.input_ids.to(torch_device))
 
-        image = torch.randn(
+        latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
-        ).to(torch_device)
+        )
+        latents = latents.to(torch_device)
 
         self.scheduler.set_timesteps(num_inference_steps)
 
-        for t in tqdm.tqdm(self.scheduler.timesteps):
-            # 1. predict noise residual
-            pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
+        for t in tqdm(self.scheduler.timesteps):
+            if guidance_scale == 1.0:
+                # guidance_scale of 1 means no guidance
+                latents_input = latents
+                context = text_embeddings
+            else:
+                # For classifier free guidance, we need to do two forward passes.
+                # Here we concatenate the unconditional and text embeddings into a single batch
+                # to avoid doing two forward passes
+                latents_input = torch.cat([latents] * 2)
+                context = torch.cat([uncond_embeddings, text_embeddings])
 
-            if isinstance(pred_noise_t, dict):
-                pred_noise_t = pred_noise_t["sample"]
+            # predict the noise residual
+            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
 
+            # perform guidance
+            if guidance_scale != 1.0:
+                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
 
-            # 2. predict previous mean of image x_t-1 and add variance depending on eta
-            # do x_t -> x_t-1
-            image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, eta)["prev_sample"]
 
-        # scale and decode image with vae
-        image = 1 / 0.18215 * image
-        image = self.vqvae.decode(image)
-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vqvae.decode(latents)
 
-        return image
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        return {"sample": image}
 
 
 ################################################################################
diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
index 79cf799f5b..0964225e8b 100644
--- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
+++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py
@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
 
     @torch.no_grad()
     def __call__(
-        self,
-        batch_size=1,
-        generator=None,
-        torch_device=None,
-        eta=0.0,
-        num_inference_steps=50,
+        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
     ):
         # eta corresponds to η in paper and should be between [0, 1]
 
@@ -28,25 +23,26 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
 
-        image = torch.randn(
+        latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
-        ).to(torch_device)
+        )
+        latents = latents.to(torch_device)
 
         self.scheduler.set_timesteps(num_inference_steps)
 
         for t in tqdm(self.scheduler.timesteps):
-            with torch.no_grad():
-                model_output = self.unet(image, t)
+            # predict the noise residual
+            noise_prediction = self.unet(latents, t)["sample"]
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
 
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
+        # decode the image latents with the VAE
+        image = self.vqvae.decode(latents)
 
-            # 2. predict previous mean of image x_t-1 and add variance depending on eta
-            # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
-        # decode image with vae
-        with torch.no_grad():
-            image = self.vqvae.decode(image)
         return {"sample": image}
diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py
index e7d4eb8ab5..33ec1a3e98 100644
--- a/src/diffusers/pipelines/pndm/pipeline_pndm.py
+++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py
@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
+    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
 
         if torch_device is None:
@@ -43,18 +43,20 @@ class PNDMPipeline(DiffusionPipeline):
         )
         image = image.to(torch_device)
 
-        prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
-        for t in tqdm(range(len(prk_time_steps))):
-            t_orig = prk_time_steps[t]
-            model_output = self.unet(image, t_orig)["sample"]
+        self.scheduler.set_timesteps(num_inference_steps)
+        for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
+            model_output = self.unet(image, t)["sample"]
 
-            image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"]
+            image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]
 
-        timesteps = self.scheduler.get_time_steps(num_inference_steps)
-        for t in tqdm(range(len(timesteps))):
-            t_orig = timesteps[t]
-            model_output = self.unet(image, t_orig)["sample"]
+        for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
+            model_output = self.unet(image, t)["sample"]
 
-            image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]
+            image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]
+
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
         return {"sample": image}
diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
index 80b0f3ef9a..5b3be8b66f 100644
--- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
+++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py
@@ -2,15 +2,16 @@ import torch
 
 from diffusers import DiffusionPipeline
+from tqdm.auto import tqdm
 
 
-# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
 class ScoreSdeVePipeline(DiffusionPipeline):
     def __init__(self, model, scheduler):
         super().__init__()
         self.register_modules(model=model, scheduler=scheduler)
 
-    def __call__(self, num_inference_steps=2000, generator=None):
+    @torch.no_grad()
+    def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
         img_size = self.model.config.image_size
@@ -24,12 +25,11 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
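The VE-SDE sampler drives two matched schedules: the discrete timesteps set on the line above and the per-step noise scales set on the line just below. In the variance-exploding formulation of Song et al. (2021), the sigmas interpolate geometrically between a minimum and a maximum noise scale; a sketch under that assumption — the default values are placeholders, the real ones come from the scheduler config:

```python
# Geometric sigma schedule for a variance-exploding SDE, largest noise first.
# sigma_min / sigma_max below are illustrative; actual values are model-specific.
import numpy as np


def ve_sigmas(num_inference_steps, sigma_min=0.01, sigma_max=380.0):
    t = np.linspace(1.0, 0.0, num_inference_steps)
    return sigma_min * (sigma_max / sigma_min) ** t
```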
         self.scheduler.set_sigmas(num_inference_steps)
 
-        for i, t in enumerate(self.scheduler.timesteps):
+        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
 
             for _ in range(self.scheduler.correct_steps):
-                with torch.no_grad():
-                    model_output = self.model(sample, sigma_t)
+                model_output = self.model(sample, sigma_t)
 
                 if isinstance(model_output, dict):
                     model_output = model_output["sample"]
@@ -45,4 +45,9 @@ class ScoreSdeVePipeline(DiffusionPipeline):
             output = self.scheduler.step_pred(model_output, t, sample)
             sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
 
-        return sample_mean
+        sample = sample.clamp(0, 1)
+        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            sample = self.numpy_to_pil(sample)
+
+        return {"sample": sample}
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index 216c4a715f..2c157e05d3 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -71,8 +71,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         self.one = np.array(1.0)
 
-        self.set_format(tensor_format=tensor_format)
-
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -82,49 +80,29 @@
         self.cur_model_output = 0
         self.cur_sample = None
         self.ets = []
-        self.prk_time_steps = {}
-        self.time_steps = {}
-        self.set_prk_mode()
 
-    def get_prk_time_steps(self, num_inference_steps):
-        if num_inference_steps in self.prk_time_steps:
-            return self.prk_time_steps[num_inference_steps]
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.prk_timesteps = None
+        self.plms_timesteps = None
 
-        inference_step_times = list(
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def set_timesteps(self, num_inference_steps):
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = list(
             range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
         )
-        prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
+        prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile(
             np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
         )
-        self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
+        self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
+        self.plms_timesteps = list(reversed(self.timesteps[:-3]))
 
-        return self.prk_time_steps[num_inference_steps]
-
-    def get_time_steps(self, num_inference_steps):
-        if num_inference_steps in self.time_steps:
-            return self.time_steps[num_inference_steps]
-
-        inference_step_times = list(
-            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
-        )
-        self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
-
-        return self.time_steps[num_inference_steps]
-
-    def set_prk_mode(self):
-        self.mode = "prk"
-
-    def set_plms_mode(self):
-        self.mode = "plms"
-
-    def step(self, *args, **kwargs):
-        if self.mode == "prk":
-            return self.step_prk(*args, **kwargs)
-        if self.mode == "plms":
-            return self.step_plms(*args, **kwargs)
-
-        raise ValueError(f"mode {self.mode} does not exist.")
+        self.set_format(tensor_format=self.tensor_format)
 
     def step_prk(
         self,
@@ -138,7 +116,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         solution to the differential equation.
         """
         t = timestep
-        prk_time_steps = self.get_prk_time_steps(num_inference_steps)
+        prk_time_steps = self.prk_timesteps
 
         t_orig = prk_time_steps[t // 4 * 4]
         t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
@@ -180,7 +158,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                 "for more information."
             )
 
-        timesteps = self.get_time_steps(num_inference_steps)
+        timesteps = self.plms_timesteps
 
         t_orig = timesteps[t]
         t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index dc7f125476..6b8b17128d 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -18,11 +18,11 @@ import inspect
 import math
 import tempfile
 import unittest
-from atexit import register
 
 import numpy as np
 import torch
+import PIL
 
 from diffusers import UNetConditionalModel  # noqa: F401 TODO(Patrick) - need to write tests with it
 from diffusers import (
     AutoencoderKL,
@@ -704,11 +704,11 @@ class PipelineTesterMixin(unittest.TestCase):
 
         generator = torch.manual_seed(0)
 
-        image = ddpm(generator=generator)["sample"]
+        image = ddpm(generator=generator, output_type="numpy")["sample"]
         generator = generator.manual_seed(0)
-        new_image = new_ddpm(generator=generator)["sample"]
+        new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
 
-        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
+        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
 
     @slow
     def test_from_pretrained_hub(self):
@@ -722,11 +722,32 @@ class PipelineTesterMixin(unittest.TestCase):
 
         generator = torch.manual_seed(0)
 
-        image = ddpm(generator=generator)["sample"]
+        image = ddpm(generator=generator, output_type="numpy")["sample"]
         generator = generator.manual_seed(0)
-        new_image = ddpm_from_hub(generator=generator)["sample"]
+        new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
 
-        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
+        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
+
+    @slow
+    def test_output_format(self):
+        model_path = "google/ddpm-cifar10-32"
+
+        pipe = DDIMPipeline.from_pretrained(model_path)
+
+        generator = torch.manual_seed(0)
+        images = pipe(generator=generator, output_type="numpy")["sample"]
+        assert images.shape == (1, 32, 32, 3)
+        assert isinstance(images, np.ndarray)
+
+        images = pipe(generator=generator, output_type="pil")["sample"]
+        assert isinstance(images, list)
+        assert len(images) == 1
+        assert isinstance(images[0], PIL.Image.Image)
+
+        # use PIL by default
+        images = pipe(generator=generator)["sample"]
+        assert isinstance(images, list)
+        assert isinstance(images[0], PIL.Image.Image)
 
     @slow
     def test_ddpm_cifar10(self):
@@ -739,15 +760,13 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)["sample"]
+        image = ddpm(generator=generator, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
     def test_ddim_lsun(self):
@@ -759,15 +778,13 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)["sample"]
+        image = ddpm(generator=generator, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor(
-            [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
     def test_ddim_cifar10(self):
@@ -779,15 +796,13 @@ class PipelineTesterMixin(unittest.TestCase):
         ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = ddim(generator=generator, eta=0.0)["sample"]
+        image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
     def test_pndm_cifar10(self):
@@ -798,15 +813,13 @@ class PipelineTesterMixin(unittest.TestCase):
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
         generator = torch.manual_seed(0)
-        image = pndm(generator=generator)["sample"]
+        image = pndm(generator=generator, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
     def test_ldm_text2img(self):
@@ -814,13 +827,15 @@ class PipelineTesterMixin(unittest.TestCase):
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm([prompt], generator=generator, num_inference_steps=20)
+        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
+            "sample"
+        ]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
     def test_ldm_text2img_fast(self):
@@ -828,55 +843,44 @@ class PipelineTesterMixin(unittest.TestCase):
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm([prompt], generator=generator, num_inference_steps=1)
+        image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
     def test_score_sde_ve_pipeline(self):
-        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024")
+        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")
 
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
 
-        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024")
+        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
 
         sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
 
         torch.manual_seed(0)
-        image = sde_ve(num_inference_steps=2)
+        image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
 
-        if model.device.type == "cpu":
-            # patrick's cpu
-            expected_image_sum = 3384805888.0
-            expected_image_mean = 1076.00085
+        image_slice = image[0, -3:, -3:, -1]
 
-            # m1 mbp
-            # expected_image_sum = 3384805376.0
-            # expected_image_mean = 1076.000610351562
-        else:
-            expected_image_sum = 3382849024.0
-            expected_image_mean = 1075.3788
-
-        assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
-        assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
     def test_ldm_uncond(self):
         ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
 
         generator = torch.manual_seed(0)
-        image = ldm(generator=generator, num_inference_steps=5)["sample"]
+        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor(
-            [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index a409426a64..3059da1661 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase):
         num_inference_steps = kwargs.pop("num_inference_steps", None)
 
         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             sample = self.dummy_sample
             residual = 0.1 * sample
 
@@ -102,7 +101,6 @@ class SchedulerCommonTest(unittest.TestCase):
         sample = self.dummy_sample
         residual = 0.1 * sample
 
-        scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         config.update(**kwargs)
         return config
 
-    def check_over_configs_pmls(self, time_step=0, **config):
+    def check_over_configs(self, time_step=0, **config):
         kwargs = dict(self.forward_default_kwargs)
         sample = self.dummy_sample
         residual = 0.1 * sample
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
 
         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config(**config)
             scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(kwargs["num_inference_steps"])
             # copy over dummy past residuals
             scheduler.ets = dummy_past_residuals[:]
-            scheduler.set_plms_mode()
 
             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
+                new_scheduler.set_timesteps(kwargs["num_inference_steps"])
                 # copy over dummy past residuals
                 new_scheduler.ets = dummy_past_residuals[:]
-                new_scheduler.set_plms_mode()
 
-            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
-            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
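The removed lines above relied on the old mode-switching API (`set_plms_mode` plus a generic `step`); the added lines below exercise the explicit two-phase API instead. A usage sketch of that API as it appears in this diff — the constant residual stands in for a real UNet forward pass:

```python
import numpy as np
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(tensor_format="np")
num_inference_steps = 50
scheduler.set_timesteps(num_inference_steps)

sample = np.zeros((1, 3, 32, 32), dtype=np.float32)
residual = 0.1 * np.ones_like(sample)  # stands in for unet(sample, t)["sample"]

for i, t in enumerate(scheduler.prk_timesteps):   # Runge-Kutta warm-up phase
    sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
for i, t in enumerate(scheduler.plms_timesteps):  # linear multistep phase
    sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
```

The PRK phase fills the `ets` history that the linear-multistep phase needs, which is why `step_plms` without prior residuals raises a `ValueError` (covered by `test_inference_plms_no_past_residuals` below).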
+            output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
 
             assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
+            output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_pretrained_save_pretrained(self):
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
         kwargs = dict(self.forward_default_kwargs)
         kwargs.update(forward_kwargs)
         sample = self.dummy_sample
@@ -409,74 +414,127 @@
         residual = 0.1 * sample
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
 
         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
+            scheduler.set_timesteps(kwargs["num_inference_steps"])
+
             # copy over dummy past residuals
             scheduler.ets = dummy_past_residuals[:]
-            scheduler.set_plms_mode()
 
             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
                 # copy over dummy past residuals
                 new_scheduler.ets = dummy_past_residuals[:]
-                new_scheduler.set_plms_mode()
+                new_scheduler.set_timesteps(kwargs["num_inference_steps"])
 
-            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
-            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+            output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
 
             assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
+            output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+            new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_pytorch_equal_numpy(self):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+
+            sample_pt = torch.tensor(sample)
+            residual_pt = 0.1 * sample_pt
+            dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            # copy over dummy past residuals
+            scheduler.ets = dummy_past_residuals[:]
+
+            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
+            # copy over dummy past residuals
+            scheduler_pt.ets = dummy_past_residuals_pt[:]
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+                scheduler_pt.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+            output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            # copy over dummy past residuals
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+            scheduler.ets = dummy_past_residuals[:]
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+            output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
     def test_timesteps(self):
         for timesteps in [100, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
 
-    def test_timesteps_pmls(self):
-        for timesteps in [100, 1000]:
-            self.check_over_configs_pmls(num_train_timesteps=timesteps)
-
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
             self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
 
-    def test_betas_pmls(self):
-        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
-            self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
-
     def test_schedules(self):
         for schedule in ["linear", "squaredcos_cap_v2"]:
             self.check_over_configs(beta_schedule=schedule)
 
-    def test_schedules_pmls(self):
-        for schedule in ["linear", "squaredcos_cap_v2"]:
-            self.check_over_configs(beta_schedule=schedule)
-
     def test_time_indices(self):
         for t in [1, 5, 10]:
             self.check_over_forward(time_step=t)
 
-    def test_time_indices_pmls(self):
-        for t in [1, 5, 10]:
-            self.check_over_forward_pmls(time_step=t)
-
     def test_inference_steps(self):
         for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
             self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
 
-    def test_inference_steps_pmls(self):
-        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
-            self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
-
-    def test_inference_pmls_no_past_residuals(self):
+    def test_inference_plms_no_past_residuals(self):
         with self.assertRaises(ValueError):
             scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
 
-            scheduler.set_plms_mode()
-
-            scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
+            scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
 
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
@@ -486,20 +544,15 @@
         num_inference_steps = 10
         model = self.dummy_model()
         sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
 
-        prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
-        for t in range(len(prk_time_steps)):
-            t_orig = prk_time_steps[t]
-            residual = model(sample, t_orig)
+        for i, t in enumerate(scheduler.prk_timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
 
-            sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
-
-        timesteps = scheduler.get_time_steps(num_inference_steps)
-        for t in range(len(timesteps)):
-            t_orig = timesteps[t]
-            residual = model(sample, t_orig)
-
-            sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
+        for i, t in enumerate(scheduler.plms_timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
 
         result_sum = np.sum(np.abs(sample))
         result_mean = np.mean(np.abs(sample))
@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         kwargs = dict(self.forward_default_kwargs)
 
         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
            sample = self.dummy_sample
             residual = 0.1 * sample
 
@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         sample = self.dummy_sample
         residual = 0.1 * sample
 
-        scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
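Taken together, the pipeline changes above mean every `__call__` now returns a `{"sample": ...}` dict holding PIL images by default, or an NHWC numpy array when requested. A short end-to-end usage sketch, using the checkpoint id from `test_output_format` above:

```python
import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

generator = torch.manual_seed(0)
images = pipe(generator=generator)["sample"]  # list of PIL.Image.Image by default
images[0].save("ddpm_sample.png")

generator = torch.manual_seed(0)
array = pipe(generator=generator, output_type="numpy")["sample"]  # shape (1, 32, 32, 3), values in [0, 1]
```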