Merge branch 'main' of https://github.com/huggingface/diffusers into main
@@ -5,18 +5,17 @@
 
 The command to train a DDPM UNet model on the Oxford Flowers dataset:
 
 ```bash
-python -m torch.distributed.launch \
-  --nproc_per_node 4 \
-  train_unconditional.py \
+accelerate launch train_unconditional.py \
   --dataset="huggan/flowers-102-categories" \
   --resolution=64 \
-  --output_dir="flowers-ddpm" \
-  --batch_size=16 \
+  --output_dir="ddpm-ema-flowers-64" \
+  --train_batch_size=16 \
   --num_epochs=100 \
   --gradient_accumulation_steps=1 \
-  --lr=1e-4 \
-  --warmup_steps=500 \
-  --mixed_precision=no
+  --learning_rate=1e-4 \
+  --lr_warmup_steps=500 \
+  --mixed_precision=no \
+  --push_to_hub
 ```
 
 A full training run takes 2 hours on 4xV100 GPUs.
@@ -29,18 +28,17 @@ A full training run takes 2 hours on 4xV100 GPUs.
 
 The command to train a DDPM UNet model on the Pokemon dataset:
 
 ```bash
-python -m torch.distributed.launch \
-  --nproc_per_node 4 \
-  train_unconditional.py \
+accelerate launch train_unconditional.py \
   --dataset="huggan/pokemon" \
   --resolution=64 \
-  --output_dir="pokemon-ddpm" \
-  --batch_size=16 \
+  --output_dir="ddpm-ema-pokemon-64" \
+  --train_batch_size=16 \
   --num_epochs=100 \
   --gradient_accumulation_steps=1 \
-  --lr=1e-4 \
-  --warmup_steps=500 \
-  --mixed_precision=no
+  --learning_rate=1e-4 \
+  --lr_warmup_steps=500 \
+  --mixed_precision=no \
+  --push_to_hub
 ```
 
 A full training run takes 2 hours on 4xV100 GPUs.
@@ -4,10 +4,10 @@ import os
 import torch
 import torch.nn.functional as F
 
-from accelerate import Accelerator, DistributedDataParallelKwargs
+from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNetUnconditionalModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -27,25 +27,37 @@ logger = get_logger(__name__)
 
 
 def main(args):
-    ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
         logging_dir=logging_dir,
-        kwargs_handlers=[ddp_unused_params],
     )
 
-    model = UNetModel(
-        attn_resolutions=(16,),
-        ch=128,
-        ch_mult=(1, 2, 4, 8),
-        dropout=0.0,
+    model = UNetUnconditionalModel(
+        image_size=args.resolution,
         in_channels=3,
         out_channels=3,
         num_res_blocks=2,
-        resamp_with_conv=True,
-        resolution=args.resolution,
+        block_channels=(128, 128, 256, 256, 512, 512),
+        down_blocks=(
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResDownBlock2D",
+            "UNetResAttnDownBlock2D",
+            "UNetResDownBlock2D",
+        ),
+        up_blocks=(
+            "UNetResUpBlock2D",
+            "UNetResAttnUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+            "UNetResUpBlock2D",
+        ),
     )
-    noise_scheduler = DDIMScheduler(timesteps=1000, tensor_format="pt")
+    noise_scheduler = DDPMScheduler(num_train_timesteps=1000, tensor_format="pt")
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
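Note on the scheduler swap above: training now pairs the UNet with a DDPMScheduler, whose `add_noise()` implements the standard DDPM forward (noising) process. A minimal sketch of that math, with an assumed linear beta schedule rather than the exact diffusers internals:

```python
import numpy as np

# Sketch of the forward (noising) process a DDPM-style scheduler implements;
# standard DDPM equations, not the exact diffusers implementation.
num_train_timesteps = 1000
betas = np.linspace(1e-4, 0.02, num_train_timesteps)  # assumed linear schedule
alphas_cumprod = np.cumprod(1.0 - betas)              # alpha_bar_t

def add_noise(clean_image, noise, t):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
    return np.sqrt(alphas_cumprod[t]) * clean_image + np.sqrt(1.0 - alphas_cumprod[t]) * noise

x0 = np.zeros((3, 64, 64))
noisy = add_noise(x0, np.random.randn(3, 64, 64), t=500)
```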
@@ -92,19 +104,6 @@ def main(args):
     run = os.path.split(__file__)[-1].split(".")[0]
     accelerator.init_trackers(run)
 
-    # Train!
-    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
-    world_size = torch.distributed.get_world_size() if is_distributed else 1
-    total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
-    max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
-    logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataloader.dataset)}")
-    logger.info(f"  Num Epochs = {args.num_epochs}")
-    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
-    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
-    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
-    logger.info(f"  Total optimization steps = {max_steps}")
-
     global_step = 0
     for epoch in range(args.num_epochs):
         model.train()
@@ -112,45 +111,37 @@ def main(args):
         progress_bar.set_description(f"Epoch {epoch}")
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
-            noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
+            # Sample noise that we'll add to the images
+            noise = torch.randn(clean_images.shape).to(clean_images.device)
             bsz = clean_images.shape[0]
-            timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
+            # Sample a random timestep for each image
+            timesteps = torch.randint(
+                0, noise_scheduler.num_train_timesteps, (bsz,), device=clean_images.device
+            ).long()
 
-            # add noise onto the clean images according to the noise magnitude at each timestep
-            noisy_images = noise_scheduler.add_noise(clean_images, noise_samples, timesteps)
+            # Add noise to the clean images according to the noise magnitude at each timestep
+            # (this is the forward diffusion process)
+            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
 
-            if step % args.gradient_accumulation_steps != 0:
-                with accelerator.no_sync(model):
-                    output = model(noisy_images, timesteps)
-                    # predict the noise residual
-                    loss = F.mse_loss(output, noise_samples)
-                    loss = loss / args.gradient_accumulation_steps
-                    accelerator.backward(loss)
-            else:
-                output = model(noisy_images, timesteps)
-                # predict the noise residual
-                loss = F.mse_loss(output, noise_samples)
-                loss = loss / args.gradient_accumulation_steps
+            with accelerator.accumulate(model):
+                # Predict the noise residual
+                noise_pred = model(noisy_images, timesteps)["sample"]
+                loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)
-                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
 
+                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
-                ema_model.step(model)
+                if args.use_ema:
+                    ema_model.step(model)
                 optimizer.zero_grad()
 
             progress_bar.update(1)
-            progress_bar.set_postfix(
-                loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"], ema_decay=ema_model.decay
-            )
-            accelerator.log(
-                {
-                    "train_loss": loss.detach().item(),
-                    "epoch": epoch,
-                    "ema_decay": ema_model.decay,
-                    "step": global_step,
-                },
-                step=global_step,
-            )
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
+            if args.use_ema:
+                logs["ema_decay"] = ema_model.decay
+            progress_bar.set_postfix(**logs)
+            accelerator.log(logs, step=global_step)
             global_step += 1
         progress_bar.close()
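The rewritten loop replaces the manual `no_sync`/modulo bookkeeping with accelerate's `accumulate()` context manager. A self-contained sketch of the pattern (written against the current accelerate API, which may be newer than what this commit targeted):

```python
import torch
from accelerate import Accelerator

# Sketch of the accumulate() pattern adopted in the training loop above.
accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(16, 8), torch.randn(16, 1))
loader = torch.utils.data.DataLoader(dataset, batch_size=2)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)  # skips the cross-process grad sync on non-boundary steps
        optimizer.step()            # the wrapped optimizer only really steps at boundaries
        optimizer.zero_grad()
```

The design win is that the loop body no longer branches on the step count; the division of the loss and the sync/no-sync decision move into accelerate.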
@@ -159,14 +150,14 @@ def main(args):
         # Generate a sample image for visual inspection
         if accelerator.is_main_process:
             with torch.no_grad():
-                pipeline = DDIMPipeline(
-                    unet=accelerator.unwrap_model(ema_model.averaged_model),
-                    noise_scheduler=noise_scheduler,
+                pipeline = DDPMPipeline(
+                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
+                    scheduler=noise_scheduler,
                 )
 
                 generator = torch.manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(generator=generator, batch_size=args.eval_batch_size, num_inference_steps=50)
+                images = pipeline(generator=generator, batch_size=args.eval_batch_size)
 
                 # denormalize the images and save to tensorboard
                 images_processed = (images.cpu() + 1.0) * 127.5
@@ -174,11 +165,12 @@ def main(args):
 
             accelerator.trackers[0].writer.add_images("test_samples", images_processed, epoch)
 
-            # save the model
-            if args.push_to_hub:
-                push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
-            else:
-                pipeline.save_pretrained(args.output_dir)
+            if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
+                # save the model
+                if args.push_to_hub:
+                    push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
+                else:
+                    pipeline.save_pretrained(args.output_dir)
         accelerator.wait_for_everyone()
 
     accelerator.end_training()
@@ -188,12 +180,13 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument("--local_rank", type=int, default=-1)
    parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model")
+    parser.add_argument("--output_dir", type=str, default="ddpm-flowers-64")
     parser.add_argument("--overwrite_output_dir", action="store_true")
     parser.add_argument("--resolution", type=int, default=64)
     parser.add_argument("--train_batch_size", type=int, default=16)
     parser.add_argument("--eval_batch_size", type=int, default=16)
     parser.add_argument("--num_epochs", type=int, default=100)
+    parser.add_argument("--save_model_epochs", type=int, default=5)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--learning_rate", type=float, default=1e-4)
     parser.add_argument("--lr_scheduler", type=str, default="cosine")
@@ -202,6 +195,7 @@ if __name__ == "__main__":
     parser.add_argument("--adam_beta2", type=float, default=0.999)
     parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
     parser.add_argument("--adam_epsilon", type=float, default=1e-3)
+    parser.add_argument("--use_ema", action="store_true", default=True)
     parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
     parser.add_argument("--ema_power", type=float, default=3 / 4)
     parser.add_argument("--ema_max_decay", type=float, default=0.9999)
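For context on the three EMA flags above: the decay is warmed up from 0 toward `ema_max_decay` over training. The sketch below matches the inverse-gamma power schedule these argument names suggest; treat it as an assumption about EMAModel's internals, not a quote of them:

```python
import torch

# Assumed EMA decay warm-up: grows from 0 toward max_decay as step increases.
def ema_decay(step, inv_gamma=1.0, power=3 / 4, max_decay=0.9999):
    value = 1 - (1 + step / inv_gamma) ** -power
    return min(max_decay, max(0.0, value))

@torch.no_grad()
def ema_update(ema_params, params, decay):
    # Shadow weights drift toward the live weights; higher decay = slower drift.
    for e, p in zip(ema_params, params):
        e.mul_(decay).add_(p, alpha=1 - decay)
```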
@@ -145,7 +145,7 @@ class Decoder(nn.Module):
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
+        # print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
 
         # z to block_in
         self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
@@ -19,6 +19,7 @@ import os
 from typing import Optional, Union
 
 from huggingface_hub import snapshot_download
+from PIL import Image
 
 from .configuration_utils import ConfigMixin
 from .utils import DIFFUSERS_CACHE, logging
@@ -120,6 +121,7 @@ class DiffusionPipeline(ConfigMixin):
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
 
         # 1. Download the checkpoints and configs
         # use snapshot download here to get it working from from_pretrained
@@ -131,6 +133,7 @@ class DiffusionPipeline(ConfigMixin):
                 proxies=proxies,
                 local_files_only=local_files_only,
                 use_auth_token=use_auth_token,
+                revision=revision,
             )
         else:
             cached_folder = pretrained_model_name_or_path
@@ -187,3 +190,15 @@ class DiffusionPipeline(ConfigMixin):
         # 5. Instantiate the pipeline
         model = pipeline_class(**init_kwargs)
         return model
+
+    @staticmethod
+    def numpy_to_pil(images):
+        """
+        Convert a numpy image or a batch of images to a PIL image.
+        """
+        if images.ndim == 3:
+            images = images[None, ...]
+        images = (images * 255).round().astype("uint8")
+        pil_images = [Image.fromarray(image) for image in images]
+
+        return pil_images
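The new `numpy_to_pil()` helper is the bridge between the pipelines' post-processed float arrays (shape `(batch, H, W, C)` in `[0, 1]`) and PIL images. A quick usage sketch:

```python
import numpy as np
from diffusers import DiffusionPipeline

# numpy_to_pil expects floats in [0, 1] with shape (batch, H, W, C), which is
# exactly what the pipelines below produce after their clamp + permute step.
images = np.random.rand(2, 32, 32, 3).astype("float32")
pil_images = DiffusionPipeline.numpy_to_pil(images)
pil_images[0].save("sample.png")
```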
@@ -28,7 +28,7 @@ class DDIMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50):
+    def __call__(self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"):
         # eta corresponds to η in paper and should be between [0, 1]
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -53,4 +53,9 @@ class DDIMPipeline(DiffusionPipeline):
             # do x_t -> x_t-1
             image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
 
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
         return {"sample": image}
@@ -28,7 +28,7 @@ class DDPMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None):
+    def __call__(self, batch_size=1, generator=None, torch_device=None, output_type="pil"):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -54,4 +54,9 @@ class DDPMPipeline(DiffusionPipeline):
             # 3. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image
 
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
         return {"sample": image}
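With both DDIM and DDPM pipelines now post-processing their outputs, callers select the format via `output_type`. A usage sketch (checkpoint name taken from the tests further down; everything else follows the new call signature):

```python
import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

generator = torch.manual_seed(0)
images = pipe(generator=generator, output_type="numpy")["sample"]  # (1, 32, 32, 3) float array in [0, 1]
pil_images = pipe(generator=generator)["sample"]                   # list of PIL images (the default)
```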
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint
 
-import tqdm
+from tqdm.auto import tqdm
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
@@ -30,51 +30,67 @@ class LatentDiffusionPipeline(DiffusionPipeline):
         eta=0.0,
         guidance_scale=1.0,
         num_inference_steps=50,
+        output_type="pil",
     ):
         # eta corresponds to η in paper and should be between [0, 1]
 
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
         batch_size = len(prompt)
 
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
         self.bert.to(torch_device)
 
-        # get unconditional embeddings for classifier free guidence
+        # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(
-                torch_device
-            )
-            uncond_embeddings = self.bert(uncond_input.input_ids)
+            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_embeddings = self.bert(uncond_input.input_ids.to(torch_device))
 
-        # get text embedding
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt").to(torch_device)
-        text_embedding = self.bert(text_input.input_ids)
+        # get prompt text embeddings
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_embeddings = self.bert(text_input.input_ids.to(torch_device))
 
-        image = torch.randn(
+        latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
-        ).to(torch_device)
+        )
+        latents = latents.to(torch_device)
 
         self.scheduler.set_timesteps(num_inference_steps)
 
-        for t in tqdm.tqdm(self.scheduler.timesteps):
-            # 1. predict noise residual
-            pred_noise_t = self.unet(image, t, encoder_hidden_states=text_embedding)
+        for t in tqdm(self.scheduler.timesteps):
+            if guidance_scale == 1.0:
+                # guidance_scale of 1 means no guidance
+                latents_input = latents
+                context = text_embeddings
+            else:
+                # For classifier free guidance, we need to do two forward passes.
+                # Here we concatenate the unconditional and text embeddings into a single batch
+                # to avoid doing two forward passes
+                latents_input = torch.cat([latents] * 2)
+                context = torch.cat([uncond_embeddings, text_embeddings])
 
-            if isinstance(pred_noise_t, dict):
-                pred_noise_t = pred_noise_t["sample"]
+            # predict the noise residual
+            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+            # perform guidance
+            if guidance_scale != 1.0:
+                noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+                noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
 
-            # 2. predict previous mean of image x_t-1 and add variance depending on eta
-            # do x_t -> x_t-1
-            image = self.scheduler.step(pred_noise_t, t, image, eta)["prev_sample"]
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_pred, t, latents, eta)["prev_sample"]
 
-        # scale and decode image with vae
-        image = 1 / 0.18215 * image
-        image = self.vqvae.decode(image)
-        image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
+        # scale and decode the image latents with vae
+        latents = 1 / 0.18215 * latents
+        image = self.vqvae.decode(latents)
 
-        return image
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        return {"sample": image}
 
 
 ################################################################################
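The rewritten loop batches the unconditional and prompt-conditional forward passes, then applies the standard classifier-free guidance combination. In isolation the arithmetic is:

```python
import torch

def classifier_free_guidance(noise_pred_uncond, noise_pred_text, guidance_scale):
    # Push the prediction away from the unconditional estimate, toward (and
    # past, for guidance_scale > 1) the text-conditional one.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# The batched trick from the diff: run the unet once on torch.cat([latents] * 2)
# with torch.cat([uncond_embeddings, text_embeddings]) as context, then split.
noise_pred = torch.randn(2, 4, 8, 8)  # stand-in for the unet output
uncond, text = noise_pred.chunk(2)
guided = classifier_free_guidance(uncond, text, guidance_scale=6.0)
```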
@@ -13,12 +13,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
 
     @torch.no_grad()
     def __call__(
-        self,
-        batch_size=1,
-        generator=None,
-        torch_device=None,
-        eta=0.0,
-        num_inference_steps=50,
+        self, batch_size=1, generator=None, torch_device=None, eta=0.0, num_inference_steps=50, output_type="pil"
     ):
         # eta corresponds to η in paper and should be between [0, 1]
 
@@ -28,25 +23,26 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
         self.unet.to(torch_device)
         self.vqvae.to(torch_device)
 
-        image = torch.randn(
+        latents = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.image_size, self.unet.image_size),
             generator=generator,
-        ).to(torch_device)
+        )
+        latents = latents.to(torch_device)
 
         self.scheduler.set_timesteps(num_inference_steps)
 
         for t in tqdm(self.scheduler.timesteps):
-            with torch.no_grad():
-                model_output = self.unet(image, t)
+            # predict the noise residual
+            noise_prediction = self.unet(latents, t)["sample"]
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.scheduler.step(noise_prediction, t, latents, eta)["prev_sample"]
 
-            if isinstance(model_output, dict):
-                model_output = model_output["sample"]
+        # decode the image latents with the VAE
+        image = self.vqvae.decode(latents)
 
-            # 2. predict previous mean of image x_t-1 and add variance depending on eta
-            # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
 
-        # decode image with vae
-        with torch.no_grad():
-            image = self.vqvae.decode(image)
+        return {"sample": image}
@@ -28,7 +28,7 @@ class PNDMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
+    def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50, output_type="pil"):
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
         if torch_device is None:
@@ -43,18 +43,20 @@ class PNDMPipeline(DiffusionPipeline):
         )
         image = image.to(torch_device)
 
-        prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
-        for t in tqdm(range(len(prk_time_steps))):
-            t_orig = prk_time_steps[t]
-            model_output = self.unet(image, t_orig)["sample"]
+        self.scheduler.set_timesteps(num_inference_steps)
+        for i, t in enumerate(tqdm(self.scheduler.prk_timesteps)):
+            model_output = self.unet(image, t)["sample"]
 
-            image = self.scheduler.step_prk(model_output, t, image, num_inference_steps)["prev_sample"]
+            image = self.scheduler.step_prk(model_output, i, image, num_inference_steps)["prev_sample"]
 
-        timesteps = self.scheduler.get_time_steps(num_inference_steps)
-        for t in tqdm(range(len(timesteps))):
-            t_orig = timesteps[t]
-            model_output = self.unet(image, t_orig)["sample"]
+        for i, t in enumerate(tqdm(self.scheduler.plms_timesteps)):
+            model_output = self.unet(image, t)["sample"]
 
-            image = self.scheduler.step_plms(model_output, t, image, num_inference_steps)["prev_sample"]
+            image = self.scheduler.step_plms(model_output, i, image, num_inference_steps)["prev_sample"]
 
+        image = (image / 2 + 0.5).clamp(0, 1)
+        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
         return {"sample": image}
@@ -2,15 +2,16 @@
 import torch
 
 from diffusers import DiffusionPipeline
+from tqdm.auto import tqdm
 
 
-# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
 class ScoreSdeVePipeline(DiffusionPipeline):
     def __init__(self, model, scheduler):
         super().__init__()
         self.register_modules(model=model, scheduler=scheduler)
 
-    def __call__(self, num_inference_steps=2000, generator=None):
+    @torch.no_grad()
+    def __call__(self, num_inference_steps=2000, generator=None, output_type="pil"):
         device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
         img_size = self.model.config.image_size
@@ -24,12 +25,11 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
 
-        for i, t in enumerate(self.scheduler.timesteps):
+        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
 
             for _ in range(self.scheduler.correct_steps):
-                with torch.no_grad():
-                    model_output = self.model(sample, sigma_t)
+                model_output = self.model(sample, sigma_t)
 
                 if isinstance(model_output, dict):
                     model_output = model_output["sample"]
@@ -45,4 +45,9 @@ class ScoreSdeVePipeline(DiffusionPipeline):
             output = self.scheduler.step_pred(model_output, t, sample)
             sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
 
-        return sample_mean
+        sample = sample.clamp(0, 1)
+        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+        if output_type == "pil":
+            sample = self.numpy_to_pil(sample)
+
+        return {"sample": sample}
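For orientation, the loop above has the usual score-SDE predictor-corrector shape: `correct_steps` Langevin-style corrector updates at each noise level, then one `step_pred` predictor move to the next level. The toy below illustrates just the Langevin corrector idea on a 1-D Gaussian whose score is known exactly (a self-contained sketch, unrelated to the scheduler's actual update rules):

```python
import torch

# Langevin dynamics: x <- x + eps * score(x) + sqrt(2 * eps) * z.
# For a standard normal target, score(x) = grad log N(0, 1) = -x, so the
# chain relaxes from a bad initialization toward samples of N(0, 1).
def score(x):
    return -x

sample = torch.randn(1000) * 10  # deliberately far from the target
step_size = 0.1
for _ in range(100):             # stand-in for the noise levels
    for _ in range(2):           # corrector steps per level
        noise = torch.randn_like(sample)
        sample = sample + step_size * score(sample) + (2 * step_size) ** 0.5 * noise
print(sample.mean().item(), sample.std().item())  # roughly 0 and 1
```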
@@ -15,6 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
+import pdb
 from typing import Union
 
 import numpy as np
@@ -71,8 +72,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
 
         self.one = np.array(1.0)
 
-        self.set_format(tensor_format=tensor_format)
-
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -82,49 +81,29 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.cur_model_output = 0
         self.cur_sample = None
         self.ets = []
-        self.prk_time_steps = {}
-        self.time_steps = {}
-        self.set_prk_mode()
 
-    def get_prk_time_steps(self, num_inference_steps):
-        if num_inference_steps in self.prk_time_steps:
-            return self.prk_time_steps[num_inference_steps]
+        # setable values
+        self.num_inference_steps = None
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.prk_timesteps = None
+        self.plms_timesteps = None
 
-        inference_step_times = list(
+        self.tensor_format = tensor_format
+        self.set_format(tensor_format=tensor_format)
+
+    def set_timesteps(self, num_inference_steps):
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = list(
             range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
         )
 
-        prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
+        prk_time_steps = np.array(self.timesteps[-self.pndm_order :]).repeat(2) + np.tile(
             np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
         )
-        self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
+        self.prk_timesteps = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
+        self.plms_timesteps = list(reversed(self.timesteps[:-3]))
 
-        return self.prk_time_steps[num_inference_steps]
-
-    def get_time_steps(self, num_inference_steps):
-        if num_inference_steps in self.time_steps:
-            return self.time_steps[num_inference_steps]
-
-        inference_step_times = list(
-            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
-        )
-        self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
-
-        return self.time_steps[num_inference_steps]
-
-    def set_prk_mode(self):
-        self.mode = "prk"
-
-    def set_plms_mode(self):
-        self.mode = "plms"
-
-    def step(self, *args, **kwargs):
-        if self.mode == "prk":
-            return self.step_prk(*args, **kwargs)
-        if self.mode == "plms":
-            return self.step_plms(*args, **kwargs)
-
-        raise ValueError(f"mode {self.mode} does not exist.")
+        self.set_format(tensor_format=self.tensor_format)
 
     def step_prk(
         self,
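To make the new `set_timesteps()` concrete, here is the arithmetic for a typical configuration, recomputed from the formulas in the hunk above (num_train_timesteps=1000, num_inference_steps=50, pndm_order=4 assumed):

```python
import numpy as np

num_inference_steps, num_train_timesteps, pndm_order = 50, 1000, 4
timesteps = list(range(0, num_train_timesteps, num_train_timesteps // num_inference_steps))
# timesteps = [0, 20, 40, ..., 980]

prk = np.array(timesteps[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, num_train_timesteps // num_inference_steps // 2]), pndm_order
)
prk_timesteps = list(reversed(prk[:-1].repeat(2)[1:-1]))
# [980, 970, 970, 960, 960, 950, 950, 940, 940, 930, 930, 920]
# The Runge-Kutta warm-up visits half-steps twice to build up self.ets.

plms_timesteps = list(reversed(timesteps[:-3]))
# [920, 900, ..., 20, 0] -- the cheaper linear multistep phase takes over.
```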
@@ -138,7 +117,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         solution to the differential equation.
         """
         t = timestep
-        prk_time_steps = self.get_prk_time_steps(num_inference_steps)
+        prk_time_steps = self.prk_timesteps
 
         t_orig = prk_time_steps[t // 4 * 4]
         t_orig_prev = prk_time_steps[min(t + 1, len(prk_time_steps) - 1)]
@@ -180,7 +159,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                 "for more information."
             )
 
-        timesteps = self.get_time_steps(num_inference_steps)
+        timesteps = self.plms_timesteps
 
         t_orig = timesteps[t]
         t_orig_prev = timesteps[min(t + 1, len(timesteps) - 1)]
@@ -18,11 +18,11 @@ import inspect
 import math
 import tempfile
 import unittest
-from atexit import register
 
 import numpy as np
 import torch
 
+import PIL
 from diffusers import UNetConditionalModel  # noqa: F401 TODO(Patrick) - need to write tests with it
 from diffusers import (
     AutoencoderKL,
@@ -704,11 +704,11 @@ class PipelineTesterMixin(unittest.TestCase):
 
         generator = torch.manual_seed(0)
 
-        image = ddpm(generator=generator)["sample"]
+        image = ddpm(generator=generator, output_type="numpy")["sample"]
         generator = generator.manual_seed(0)
-        new_image = new_ddpm(generator=generator)["sample"]
+        new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
 
-        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
+        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
 
     @slow
     def test_from_pretrained_hub(self):
@@ -722,11 +722,32 @@ class PipelineTesterMixin(unittest.TestCase):
 
         generator = torch.manual_seed(0)
 
-        image = ddpm(generator=generator)["sample"]
+        image = ddpm(generator=generator, output_type="numpy")["sample"]
         generator = generator.manual_seed(0)
-        new_image = ddpm_from_hub(generator=generator)["sample"]
+        new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
 
-        assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
+        assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
+
+    @slow
+    def test_output_format(self):
+        model_path = "google/ddpm-cifar10-32"
+
+        pipe = DDIMPipeline.from_pretrained(model_path)
+
+        generator = torch.manual_seed(0)
+        images = pipe(generator=generator, output_type="numpy")["sample"]
+        assert images.shape == (1, 32, 32, 3)
+        assert isinstance(images, np.ndarray)
+
+        images = pipe(generator=generator, output_type="pil")["sample"]
+        assert isinstance(images, list)
+        assert len(images) == 1
+        assert isinstance(images[0], PIL.Image.Image)
+
+        # use PIL by default
+        images = pipe(generator=generator)["sample"]
+        assert isinstance(images, list)
+        assert isinstance(images[0], PIL.Image.Image)
 
     @slow
     def test_ddpm_cifar10(self):
@@ -739,15 +760,13 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)["sample"]
+        image = ddpm(generator=generator, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.1601, -0.2823, -0.6123, -0.2305, -0.3236, -0.4706, -0.1691, -0.2836, -0.3231]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -759,15 +778,13 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator)["sample"]
+        image = ddpm(generator=generator, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor(
-            [-0.9879, -0.9598, -0.9312, -0.9953, -0.9963, -0.9995, -0.9957, -1.0000, -0.9863]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -779,15 +796,13 @@ class PipelineTesterMixin(unittest.TestCase):
         ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
 
         generator = torch.manual_seed(0)
-        image = ddim(generator=generator, eta=0.0)["sample"]
+        image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.6553, -0.6765, -0.6799, -0.6749, -0.7006, -0.6974, -0.6991, -0.7116, -0.7094]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -798,15 +813,13 @@ class PipelineTesterMixin(unittest.TestCase):
 
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
         generator = torch.manual_seed(0)
-        image = pndm(generator=generator)["sample"]
+        image = pndm(generator=generator, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 32, 32)
-        expected_slice = torch.tensor(
-            [-0.6872, -0.7071, -0.7188, -0.7057, -0.7515, -0.7191, -0.7377, -0.7565, -0.7500]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.1564, 0.14645, 0.1406, 0.14715, 0.12425, 0.14045, 0.13115, 0.12175, 0.125])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -814,13 +827,15 @@ class PipelineTesterMixin(unittest.TestCase):
 
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm([prompt], generator=generator, num_inference_steps=20)
+        image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
+            "sample"
+        ]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -828,55 +843,44 @@ class PipelineTesterMixin(unittest.TestCase):
 
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm([prompt], generator=generator, num_inference_steps=1)
+        image = ldm([prompt], generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
     def test_score_sde_ve_pipeline(self):
-        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-ffhq-1024")
+        model = UNetUnconditionalModel.from_pretrained("google/ncsnpp-church-256")
 
-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
-
-        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-ffhq-1024")
+        scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
 
         sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
 
         torch.manual_seed(0)
-        image = sde_ve(num_inference_steps=2)
+        image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
 
-        if model.device.type == "cpu":
-            # patrick's cpu
-            expected_image_sum = 3384805888.0
-            expected_image_mean = 1076.00085
+        image_slice = image[0, -3:, -3:, -1]
 
-            # m1 mbp
-            # expected_image_sum = 3384805376.0
-            # expected_image_mean = 1076.000610351562
-        else:
-            expected_image_sum = 3382849024.0
-            expected_image_mean = 1075.3788
-
-        assert (image.abs().sum() - expected_image_sum).abs().cpu().item() < 1e-2
-        assert (image.abs().mean() - expected_image_mean).abs().cpu().item() < 1e-4
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
     def test_ldm_uncond(self):
         ldm = LatentDiffusionUncondPipeline.from_pretrained("CompVis/ldm-celebahq-256")
 
         generator = torch.manual_seed(0)
-        image = ldm(generator=generator, num_inference_steps=5)["sample"]
+        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
 
-        image_slice = image[0, -1, -3:, -3:].cpu()
+        image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 3, 256, 256)
-        expected_slice = torch.tensor(
-            [-0.1202, -0.1005, -0.0635, -0.0520, -0.1282, -0.0838, -0.0981, -0.1318, -0.1106]
-        )
-        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.4399, 0.44975, 0.46825, 0.474, 0.4359, 0.4581, 0.45095, 0.4341, 0.4447])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@@ -70,7 +70,6 @@ class SchedulerCommonTest(unittest.TestCase):
         num_inference_steps = kwargs.pop("num_inference_steps", None)
 
         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             sample = self.dummy_sample
             residual = 0.1 * sample
 
@@ -102,7 +101,6 @@ class SchedulerCommonTest(unittest.TestCase):
         sample = self.dummy_sample
         residual = 0.1 * sample
 
-        scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
 
@@ -375,33 +373,40 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         config.update(**kwargs)
         return config
 
-    def check_over_configs_pmls(self, time_step=0, **config):
+    def check_over_configs(self, time_step=0, **config):
         kwargs = dict(self.forward_default_kwargs)
         sample = self.dummy_sample
         residual = 0.1 * sample
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
 
-        for scheduler_class in self.scheduler_classes:
-            scheduler_config = self.get_scheduler_config(**config)
-            scheduler = scheduler_class(**scheduler_config)
-            # copy over dummy past residuals
-            scheduler.ets = dummy_past_residuals[:]
-            scheduler.set_plms_mode()
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config(**config)
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler.set_timesteps(kwargs["num_inference_steps"])
+        # copy over dummy past residuals
+        scheduler.ets = dummy_past_residuals[:]
 
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                scheduler.save_config(tmpdirname)
-                new_scheduler = scheduler_class.from_config(tmpdirname)
-                # copy over dummy past residuals
-                new_scheduler.ets = dummy_past_residuals[:]
-                new_scheduler.set_plms_mode()
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            scheduler.save_config(tmpdirname)
+            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler.set_timesteps(kwargs["num_inference_steps"])
+            # copy over dummy past residuals
+            new_scheduler.ets = dummy_past_residuals[:]
 
-            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
-            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+        output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+        new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
 
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+        assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-    def check_over_forward_pmls(self, time_step=0, **forward_kwargs):
+        output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+        new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+
+        assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_from_pretrained_save_pretrained(self):
+        pass
+
+    def check_over_forward(self, time_step=0, **forward_kwargs):
         kwargs = dict(self.forward_default_kwargs)
         kwargs.update(forward_kwargs)
         sample = self.dummy_sample
@@ -409,74 +414,127 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
 
-        for scheduler_class in self.scheduler_classes:
-            scheduler_config = self.get_scheduler_config()
-            scheduler = scheduler_class(**scheduler_config)
-
-            # copy over dummy past residuals
-            scheduler.ets = dummy_past_residuals[:]
-            scheduler.set_plms_mode()
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+        scheduler.set_timesteps(kwargs["num_inference_steps"])
 
-            with tempfile.TemporaryDirectory() as tmpdirname:
-                scheduler.save_config(tmpdirname)
-                new_scheduler = scheduler_class.from_config(tmpdirname)
-                # copy over dummy past residuals
-                new_scheduler.ets = dummy_past_residuals[:]
-                new_scheduler.set_plms_mode()
+        # copy over dummy past residuals
+        scheduler.ets = dummy_past_residuals[:]
 
-            output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
-            new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            scheduler.save_config(tmpdirname)
+            new_scheduler = scheduler_class.from_config(tmpdirname)
+            # copy over dummy past residuals
+            new_scheduler.ets = dummy_past_residuals[:]
+            new_scheduler.set_timesteps(kwargs["num_inference_steps"])
 
-            assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+        output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+        new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+
+        assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+        output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+        new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+
+        assert np.sum(np.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
+
+    def test_pytorch_equal_numpy(self):
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+
+            sample_pt = torch.tensor(sample)
+            residual_pt = 0.1 * sample_pt
+            dummy_past_residuals_pt = [residual_pt + 0.2, residual_pt + 0.15, residual_pt + 0.1, residual_pt + 0.05]
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            # copy over dummy past residuals
+            scheduler.ets = dummy_past_residuals[:]
+
+            scheduler_pt = scheduler_class(tensor_format="pt", **scheduler_config)
+            # copy over dummy past residuals
+            scheduler_pt.ets = dummy_past_residuals_pt[:]
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+                scheduler_pt.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+            output = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, num_inference_steps, **kwargs)["prev_sample"]
+
+            assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
+
+    def test_step_shape(self):
+        kwargs = dict(self.forward_default_kwargs)
+
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+            # copy over dummy past residuals
+            dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
+            scheduler.ets = dummy_past_residuals[:]
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            output_0 = scheduler.step_prk(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_prk(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
+
+            output_0 = scheduler.step_plms(residual, 0, sample, num_inference_steps, **kwargs)["prev_sample"]
+            output_1 = scheduler.step_plms(residual, 1, sample, num_inference_steps, **kwargs)["prev_sample"]
+
+            self.assertEqual(output_0.shape, sample.shape)
+            self.assertEqual(output_0.shape, output_1.shape)
 
     def test_timesteps(self):
         for timesteps in [100, 1000]:
             self.check_over_configs(num_train_timesteps=timesteps)
 
-    def test_timesteps_pmls(self):
-        for timesteps in [100, 1000]:
-            self.check_over_configs_pmls(num_train_timesteps=timesteps)
-
     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
             self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
 
-    def test_betas_pmls(self):
-        for beta_start, beta_end in zip([0.0001, 0.001, 0.01], [0.002, 0.02, 0.2]):
-            self.check_over_configs_pmls(beta_start=beta_start, beta_end=beta_end)
-
     def test_schedules(self):
         for schedule in ["linear", "squaredcos_cap_v2"]:
             self.check_over_configs(beta_schedule=schedule)
 
-    def test_schedules_pmls(self):
-        for schedule in ["linear", "squaredcos_cap_v2"]:
-            self.check_over_configs(beta_schedule=schedule)
-
     def test_time_indices(self):
         for t in [1, 5, 10]:
             self.check_over_forward(time_step=t)
 
-    def test_time_indices_pmls(self):
-        for t in [1, 5, 10]:
-            self.check_over_forward_pmls(time_step=t)
-
     def test_inference_steps(self):
         for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
             self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
 
-    def test_inference_steps_pmls(self):
-        for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
-            self.check_over_forward_pmls(time_step=t, num_inference_steps=num_inference_steps)
-
-    def test_inference_pmls_no_past_residuals(self):
+    def test_inference_plms_no_past_residuals(self):
         with self.assertRaises(ValueError):
             scheduler_class = self.scheduler_classes[0]
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
 
-            scheduler.set_plms_mode()
-
-            scheduler.step(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
+            scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample, 50)["prev_sample"]
 
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
@@ -486,20 +544,15 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         num_inference_steps = 10
         model = self.dummy_model()
         sample = self.dummy_sample_deter
+        scheduler.set_timesteps(num_inference_steps)
 
-        prk_time_steps = scheduler.get_prk_time_steps(num_inference_steps)
-        for t in range(len(prk_time_steps)):
-            t_orig = prk_time_steps[t]
-            residual = model(sample, t_orig)
+        for i, t in enumerate(scheduler.prk_timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step_prk(residual, i, sample, num_inference_steps)["prev_sample"]
 
-            sample = scheduler.step_prk(residual, t, sample, num_inference_steps)["prev_sample"]
-
-        timesteps = scheduler.get_time_steps(num_inference_steps)
-        for t in range(len(timesteps)):
-            t_orig = timesteps[t]
-            residual = model(sample, t_orig)
-
-            sample = scheduler.step_plms(residual, t, sample, num_inference_steps)["prev_sample"]
+        for i, t in enumerate(scheduler.plms_timesteps):
+            residual = model(sample, t)
+            sample = scheduler.step_plms(residual, i, sample, num_inference_steps)["prev_sample"]
 
         result_sum = np.sum(np.abs(sample))
         result_mean = np.mean(np.abs(sample))
@@ -562,7 +615,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         kwargs = dict(self.forward_default_kwargs)
 
         for scheduler_class in self.scheduler_classes:
-            scheduler_class = self.scheduler_classes[0]
             sample = self.dummy_sample
             residual = 0.1 * sample
 
@@ -591,7 +643,6 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         sample = self.dummy_sample
         residual = 0.1 * sample
 
-        scheduler_class = self.scheduler_classes[0]
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)