1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

V prediction ddim (#1313)

* v diffusion support for ddpm

* quality and style

* variable name consistency

* missing base case

* pass prediction type along in the pipeline

* put prediction type in scheduler config

* style

* try to train on ddim

* changes to ddim

* ddim v prediction works to train butterflies example

* fix bad merge, style and quality

* try to fix broken doc strings

* second pass

* one more

* white space

* Update src/diffusers/schedulers/scheduling_ddim.py

* remove extra lines

* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Ben Glickenhaus <ben@mail.cs.umass.edu>
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
This commit is contained in:
Ben Glickenhaus
2022-11-17 13:26:19 -05:00
committed by GitHub
parent 56164f56fb
commit 11362ae5d2
2 changed files with 299 additions and 17 deletions

View File

@@ -0,0 +1,227 @@
import glob
import os
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_cosine_schedule_with_warmup
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
@dataclass
class TrainingConfig:
image_size = 128 # the generated image resolution
train_batch_size = 16
eval_batch_size = 16 # how many images to sample during evaluation
num_epochs = 50
gradient_accumulation_steps = 1
learning_rate = 5e-5
lr_warmup_steps = 500
save_image_epochs = 10
save_model_epochs = 30
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
output_dir = "ddim-butterflies-128-v-diffusion" # the model namy locally and on the HF Hub
push_to_hub = False # whether to upload the saved model to the HF Hub
hub_private_repo = False
overwrite_output_dir = True # overwrite the old model when re-running the notebook
seed = 0
config = TrainingConfig()
config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")
preprocess = transforms.Compose(
[
transforms.Resize((config.image_size, config.image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def transform(examples):
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
return {"images": images}
dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
model = UNet2DModel(
sample_size=config.image_size, # the target image resolution
in_channels=3, # the number of input channels, 3 for RGB images
out_channels=3, # the number of output channels
layers_per_block=2, # how many ResNet layers to use per UNet block
block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channes for each UNet block
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
"DownBlock2D",
),
up_block_types=(
"UpBlock2D", # a regular ResNet upsampling block
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
),
)
if config.output_dir.startswith("ddpm"):
noise_scheduler = DDPMScheduler(
num_train_timesteps=1000,
beta_schedule="squaredcos_cap_v2",
variance_type="v_diffusion",
prediction_type="v",
)
else:
noise_scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_schedule="squaredcos_cap_v2",
variance_type="v_diffusion",
prediction_type="v",
)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=config.lr_warmup_steps,
num_training_steps=(len(train_dataloader) * config.num_epochs),
)
def make_grid(images, rows, cols):
w, h = images[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, image in enumerate(images):
grid.paste(image, box=(i % cols * w, i // cols * h))
return grid
def evaluate(config, epoch, pipeline):
# Sample some images from random noise (this is the backward diffusion process).
# The default pipeline output type is `List[PIL.Image]`
images = pipeline(
batch_size=config.eval_batch_size,
generator=torch.manual_seed(config.seed),
).images
# Make a grid out of the images
image_grid = make_grid(images, rows=4, cols=4)
# Save the images
test_dir = os.path.join(config.output_dir, "samples")
os.makedirs(test_dir, exist_ok=True)
image_grid.save(f"{test_dir}/{epoch:04d}.png")
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
# Initialize accelerator and tensorboard logging
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with="tensorboard",
logging_dir=os.path.join(config.output_dir, "logs"),
)
if accelerator.is_main_process:
if config.push_to_hub:
repo = init_git_repo(config, at_init=True)
accelerator.init_trackers("train_example")
# Prepare everything
# There is no specific order to remember, you just need to unpack the
# objects in the same order you gave them to the prepare method.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
global_step = 0
if config.output_dir.startswith("ddpm"):
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
else:
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
evaluate(config, 0, pipeline)
# Now you train the model
for epoch in range(config.num_epochs):
progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
progress_bar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["images"]
# Sample noise to add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
bs = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()
with accelerator.accumulate(model):
# Predict the noise residual
alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device)
z_t = alpha_t * clean_images + sigma_t * noise
noise_pred = model(z_t, timesteps).sample
v = alpha_t * noise - sigma_t * clean_images
loss = F.mse_loss(noise_pred, v)
accelerator.backward(loss)
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
global_step += 1
# After each epoch you optionally sample some demo images with evaluate() and save the model
if accelerator.is_main_process:
if config.output_dir.startswith("ddpm"):
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
else:
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
evaluate(config, epoch, pipeline)
if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
if config.push_to_hub:
push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
else:
pipeline.save_pretrained(config.output_dir)
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
train_loop(*args)
sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
Image.open(sample_images[-1])

View File

@@ -17,7 +17,7 @@
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Literal, Optional, Tuple, Union
import numpy as np
import torch
@@ -27,6 +27,17 @@ from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
def expand_to_shape(input, timesteps, shape, device):
"""
Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
nicely with `shape`. Useful for parellizing operations over `shape[0]` number of diffusion steps at once.
"""
out = torch.gather(input.to(device), 0, timesteps.to(device))
reshape = [shape[0]] + [1] * (len(shape) - 1)
out = out.reshape(*reshape)
return out
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
@@ -75,6 +86,18 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
return torch.tensor(betas)
def t_to_alpha_sigma(num_diffusion_timesteps):
"""Returns the scaling factors for the clean image and for the noise, given
a timestep."""
alphas = torch.cos(
torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)])
)
sigmas = torch.sin(
torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)])
)
return alphas, sigmas
class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
@@ -128,7 +151,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
variance_type: str = "fixed",
steps_offset: int = 0,
prediction_type: Literal["epsilon", "sample", "v"] = "epsilon",
**kwargs,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
@@ -145,15 +171,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.variance_type = variance_type
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sigmas = 1 - self.alphas**2
if prediction_type == "v":
self.alphas, self.sigmas = t_to_alpha_sigma(num_train_timesteps)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
self.final_sigma = torch.tensor(0.0) if set_alpha_to_one else self.sigmas[0]
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
@@ -161,6 +190,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
self.variance_type = variance_type
self.prediction_type = prediction_type
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
@@ -170,20 +201,31 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
Args:
sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
def _get_variance(self, timestep, prev_timestep):
def _get_variance(self, timestep, prev_timestep, eta=0):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
if self.variance_type == "fixed":
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
elif self.variance_type == "v_diffusion":
# If eta > 0, adjust the scaling factor for the predicted noise
# downward according to the amount of additional noise to add
alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
sigma_prev = self.sigmas[prev_timestep] if prev_timestep >= 0 else self.final_sigma
if eta:
numerator = eta * (sigma_prev**2 / self.sigmas[timestep] ** 2).clamp(min=1.0e-7).sqrt()
else:
numerator = 0
denominator = (1 - self.alphas[timestep] ** 2 / alpha_prev**2).clamp(min=1.0e-7).sqrt()
ddim_sigma = (numerator * denominator).clamp(min=1.0e-7)
variance = (sigma_prev**2 - ddim_sigma**2).clamp(min=1.0e-7).sqrt()
return variance
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
@@ -207,7 +249,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
prediction_type: str = "epsilon",
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
@@ -271,19 +312,21 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if prediction_type == "epsilon":
if self.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
eps = torch.tensor(1)
elif prediction_type == "sample":
elif self.prediction_type == "sample":
pred_original_sample = model_output
eps = torch.tensor(1)
elif prediction_type == "v":
elif self.prediction_type == "v":
# v_t = alpha_t * epsilon - sigma_t * x
# need to merge the PRs for sigma to be available in DDPM
pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep]
eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep]
eps = model_output * self.alphas[timestep] + sample * self.sigmas[timestep]
else:
raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`")
raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`"
)
# 4. Clip "predicted x_0"
if self.config.clip_sample:
@@ -291,7 +334,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
variance = self._get_variance(timestep, prev_timestep)
variance = self._get_variance(timestep, prev_timestep, eta)
std_dev_t = eta * variance ** (0.5)
if use_clipped_model_output:
@@ -299,10 +342,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
if self.prediction_type == "epsilon":
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction
else:
alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
prev_sample = pred_original_sample * alpha_prev + eps * variance
if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
@@ -325,7 +372,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample,)
@@ -337,6 +383,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
if self.variance_type == "v_diffusion":
alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
z_t = alpha * original_samples + sigma * noise
return z_t
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
@@ -356,3 +406,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
def __len__(self):
return self.config.num_train_timesteps
def get_alpha_sigma(self, sample, timesteps, device):
alpha = expand_to_shape(self.alphas, timesteps, sample.shape, device)
sigma = expand_to_shape(self.sigmas, timesteps, sample.shape, device)
return alpha, sigma