Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
V prediction ddim (#1313)
* v diffusion support for ddpm
* quality and style
* variable name consistency
* missing base case
* pass prediction type along in the pipeline
* put prediction type in scheduler config
* style
* try to train on ddim
* changes to ddim
* ddim v prediction works to train butterflies example
* fix bad merge, style and quality
* try to fix broken doc strings
* second pass
* one more
* white space
* Update src/diffusers/schedulers/scheduling_ddim.py
* remove extra lines
* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Ben Glickenhaus <ben@mail.cs.umass.edu>
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
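With `prediction_type="v"` the model is trained to predict the velocity v_t = alpha_t * noise - sigma_t * x_0 rather than the noise itself (the parameterization from Salimans & Ho, "Progressive Distillation for Fast Sampling of Diffusion Models"), and the choice now lives in the scheduler config instead of being passed to `step()`. A minimal sketch of the new configuration, with the keyword values taken from the training example below:

from diffusers import DDIMScheduler

noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_schedule="squaredcos_cap_v2",
    variance_type="v_diffusion",
    prediction_type="v",
)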
examples/v_prediction/train_butterflies.py (new file, 227 lines)
@@ -0,0 +1,227 @@
import glob
import os
from dataclasses import dataclass

import torch
import torch.nn.functional as F

from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_cosine_schedule_with_warmup
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm


@dataclass
class TrainingConfig:
    image_size = 128  # the generated image resolution
    train_batch_size = 16
    eval_batch_size = 16  # how many images to sample during evaluation
    num_epochs = 50
    gradient_accumulation_steps = 1
    learning_rate = 5e-5
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 30
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "ddim-butterflies-128-v-diffusion"  # the model name locally and on the HF Hub

    push_to_hub = False  # whether to upload the saved model to the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0

config = TrainingConfig()


config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")


preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}


dataset.set_transform(transform)


train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

if config.output_dir.startswith("ddpm"):
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_schedule="squaredcos_cap_v2",
        variance_type="v_diffusion",
        prediction_type="v",
    )
else:
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule="squaredcos_cap_v2",
        variance_type="v_diffusion",
        prediction_type="v",
    )
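# Both branches use the cosine ("squaredcos_cap_v2") schedule together with v prediction. When
# prediction_type="v", the scheduler replaces its alphas/sigmas with the mapping from
# t_to_alpha_sigma, i.e. alpha_t = cos(pi/2 * t/T) and sigma_t = sin(pi/2 * t/T), so that
# alpha_t**2 + sigma_t**2 == 1. The model is then asked to predict v_t = alpha_t * noise
# - sigma_t * x_0 rather than the noise itself (see the training loop below).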


optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)


lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)
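# Rough scale of this schedule, assuming the butterflies subset holds about 1,000 images:
# ~1,000 / 16 ≈ 63 optimizer steps per epoch and ~3,150 steps over 50 epochs, of which the
# first 500 are linear warmup before the cosine decay.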
def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


def evaluate(config, epoch, pipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(config.seed),
    ).images

    # Make a grid out of the images
    image_grid = make_grid(images, rows=4, cols=4)

    # Save the images
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")


def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        logging_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        if config.push_to_hub:
            repo = init_git_repo(config, at_init=True)
        accelerator.init_trackers("train_example")

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    if config.output_dir.startswith("ddpm"):
        pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
    else:
        pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

    evaluate(config, 0, pipeline)

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            clean_images = batch["images"]
            # Sample noise to add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bs = clean_images.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()

            with accelerator.accumulate(model):
                # Build the noised sample z_t and the velocity target v_t, then predict v_t with the model
                alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device)
                z_t = alpha_t * clean_images + sigma_t * noise
                noise_pred = model(z_t, timesteps).sample
                v = alpha_t * noise - sigma_t * clean_images
                loss = F.mse_loss(noise_pred, v)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            if config.output_dir.startswith("ddpm"):
                pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
            else:
                pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                if config.push_to_hub:
                    push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
                else:
                    pipeline.save_pretrained(config.output_dir)


args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

train_loop(*args)

sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
Image.open(sample_images[-1])
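Once training has finished, the saved pipeline can be reloaded on its own for sampling. A minimal sketch, assuming the run kept the default `output_dir` above and saved with `save_pretrained` (the `num_inference_steps` value is only illustrative):

from diffusers import DDIMPipeline

pipeline = DDIMPipeline.from_pretrained("ddim-butterflies-128-v-diffusion")
images = pipeline(batch_size=4, num_inference_steps=50).images
images[0].save("butterfly_sample.png")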

src/diffusers/schedulers/scheduling_ddim.py
@@ -17,7 +17,7 @@

import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Literal, Optional, Tuple, Union

import numpy as np
import torch
@@ -27,6 +27,17 @@ from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


+def expand_to_shape(input, timesteps, shape, device):
+    """
+    Helper that indexes a 1D tensor `input` with a 1D index tensor `timesteps`, then reshapes the result to
+    broadcast nicely with `shape`. Useful for parallelizing operations over `shape[0]` diffusion steps at once.
+    """
+    out = torch.gather(input.to(device), 0, timesteps.to(device))
+    reshape = [shape[0]] + [1] * (len(shape) - 1)
+    out = out.reshape(*reshape)
+    return out
+
+
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
class DDIMSchedulerOutput(BaseOutput):
@@ -75,6 +86,18 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
    return torch.tensor(betas)


+def t_to_alpha_sigma(num_diffusion_timesteps):
+    """Returns the scaling factors for the clean image and for the noise, given
+    a timestep."""
+    alphas = torch.cos(
+        torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)])
+    )
+    sigmas = torch.sin(
+        torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)])
+    )
+    return alphas, sigmas
+
+
class DDIMScheduler(SchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
@@ -128,7 +151,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        trained_betas: Optional[np.ndarray] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
+        variance_type: str = "fixed",
        steps_offset: int = 0,
+        prediction_type: Literal["epsilon", "sample", "v"] = "epsilon",
+        **kwargs,
    ):
        if trained_betas is not None:
            self.betas = torch.from_numpy(trained_betas)
@@ -145,15 +171,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

+        self.variance_type = variance_type
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+        self.sigmas = 1 - self.alphas**2
+        if prediction_type == "v":
+            self.alphas, self.sigmas = t_to_alpha_sigma(num_train_timesteps)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+        self.final_sigma = torch.tensor(0.0) if set_alpha_to_one else self.sigmas[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0
@@ -161,6 +190,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+        self.variance_type = variance_type
+        self.prediction_type = prediction_type

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
@@ -170,20 +201,31 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

-    def _get_variance(self, timestep, prev_timestep):
+    def _get_variance(self, timestep, prev_timestep, eta=0):
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

-        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+        if self.variance_type == "fixed":
+            variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+        elif self.variance_type == "v_diffusion":
+            # If eta > 0, adjust the scaling factor for the predicted noise
+            # downward according to the amount of additional noise to add
+            alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+            sigma_prev = self.sigmas[prev_timestep] if prev_timestep >= 0 else self.final_sigma
+            if eta:
+                numerator = eta * (sigma_prev**2 / self.sigmas[timestep] ** 2).clamp(min=1.0e-7).sqrt()
+            else:
+                numerator = 0
+            denominator = (1 - self.alphas[timestep] ** 2 / alpha_prev**2).clamp(min=1.0e-7).sqrt()
+            ddim_sigma = (numerator * denominator).clamp(min=1.0e-7)
+            variance = (sigma_prev**2 - ddim_sigma**2).clamp(min=1.0e-7).sqrt()
        return variance

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
@@ -207,7 +249,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
-        prediction_type: str = "epsilon",
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
@@ -271,19 +312,21 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        if prediction_type == "epsilon":
+        if self.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            eps = torch.tensor(1)
-        elif prediction_type == "sample":
+        elif self.prediction_type == "sample":
            pred_original_sample = model_output
            eps = torch.tensor(1)
-        elif prediction_type == "v":
+        elif self.prediction_type == "v":
            # v_t = alpha_t * epsilon - sigma_t * x
            # need to merge the PRs for sigma to be available in DDPM
            pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep]
-            eps = model_output * self.alphas[timestep] - sample * self.sigmas[timestep]
+            eps = model_output * self.alphas[timestep] + sample * self.sigmas[timestep]
        else:
-            raise ValueError(f"prediction_type given as {prediction_type} must be one of `epsilon`, `sample`, or `v`")
+            raise ValueError(
+                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v`"
+            )

        # 4. Clip "predicted x_0"
        if self.config.clip_sample:
@@ -291,7 +334,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-        variance = self._get_variance(timestep, prev_timestep)
+        variance = self._get_variance(timestep, prev_timestep, eta)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
@@ -299,10 +342,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+        if self.prediction_type == "epsilon":
+            pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

-        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction
+            # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+            prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction
+        else:
+            alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+            prev_sample = pred_original_sample * alpha_prev + eps * variance

        if eta > 0:
            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
@@ -325,7 +372,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
            variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise

            prev_sample = prev_sample + variance

        if not return_dict:
            return (prev_sample,)
@@ -337,6 +383,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
+        if self.variance_type == "v_diffusion":
+            alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
+            z_t = alpha * original_samples + sigma * noise
+            return z_t
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)
@@ -356,3 +406,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

    def __len__(self):
        return self.config.num_train_timesteps
+
+    def get_alpha_sigma(self, sample, timesteps, device):
+        alpha = expand_to_shape(self.alphas, timesteps, sample.shape, device)
+        sigma = expand_to_shape(self.sigmas, timesteps, sample.shape, device)
+        return alpha, sigma
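Taken together, `t_to_alpha_sigma` defines alpha_t = cos(pi/2 * t/T) and sigma_t = sin(pi/2 * t/T), `get_alpha_sigma` broadcasts them over a batch via `expand_to_shape`, `add_noise` builds z_t = alpha_t * x_0 + sigma_t * noise, and `step` uses alpha_t**2 + sigma_t**2 == 1 to turn a v prediction back into x_0 and eps. A minimal sketch of that round trip, assuming a scheduler built from this patched version (shapes and values are only illustrative):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_schedule="squaredcos_cap_v2",
    variance_type="v_diffusion",
    prediction_type="v",
)

x0 = torch.randn(4, 3, 128, 128)  # stand-in for clean images
eps = torch.randn_like(x0)        # noise
t = torch.randint(0, 1000, (4,))  # one timestep per sample

alpha, sigma = scheduler.get_alpha_sigma(x0, t, x0.device)  # each of shape (4, 1, 1, 1)
z = alpha * x0 + sigma * eps      # noised sample, as in add_noise
v = alpha * eps - sigma * x0      # the v-prediction training target

# Because alpha**2 + sigma**2 == 1, a v prediction can be inverted:
x0_rec = alpha * z - sigma * v    # matches pred_original_sample in step()
eps_rec = sigma * z + alpha * v   # matches eps in step()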