mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
Patrick von Platen
2022-06-13 16:03:49 +00:00
11 changed files with 369 additions and 119 deletions

131
README.md
View File

@@ -164,7 +164,7 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```
**Text to Image generation with Latent Diffusion**
#### **Text to Image generation with Latent Diffusion**
```python
from diffusers import DiffusionPipeline
@@ -184,59 +184,98 @@ image_pil = PIL.Image.fromarray(image_processed[0])
# save image
image_pil.save("test.png")
```
#### **Text to speech with BDDM**
_Follow the instructions [here](https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/) to load the tacotron2 model._
```python
import torch
from diffusers import BDDM, DiffusionPipeline
torch_device = "cuda"
# load the BDDM pipeline
bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder")
# load tacotron2 to get the mel spectrograms
tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')
tacotron2 = tacotron2.to(torch_device).eval()
text = "Hello world, I missed you so much."
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')
sequences, lengths = utils.prepare_input_sequence([text])
# generate mel spectrograms from the text
with torch.no_grad():
    mel_spec, _, _ = tacotron2.infer(sequences, lengths)
# generate speech by passing the mel spectrograms to the BDDM pipeline
generator = torch.manual_seed(0)
audio = bddm(mel_spec, generator, torch_device)
# save generated audio
from scipy.io.wavfile import write as wavwrite
sampling_rate = 22050
wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
```
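If you are running this in a notebook, the generated waveform can also be played back inline. A minimal sketch, assuming the `audio` tensor and the 22050 Hz sampling rate from the snippet above:

```python
# Optional: listen to the generated audio inline in a Jupyter notebook.
# Assumes `audio` is the tensor returned by the BDDM pipeline above.
from IPython.display import Audio

Audio(audio.squeeze().cpu().numpy(), rate=22050)
```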
## Library structure:
```
├── models
│   ├── audio
│   │   └── fastdiff
│   │       ├── modeling_fastdiff.py
│   │       ├── README.md
│   │       └── run_fastdiff.py
│   ├── __init__.py
│   └── vision
│       ├── dalle2
│       │   ├── modeling_dalle2.py
│       │   ├── README.md
│       │   └── run_dalle2.py
│       ├── ddpm
│       │   ├── example.py
│       │   ├── modeling_ddpm.py
│       │   ├── README.md
│       │   └── run_ddpm.py
│       ├── glide
│       │   ├── modeling_glide.py
│       │   ├── modeling_vqvae.py.py
│       │   ├── README.md
│       │   └── run_glide.py
│       ├── imagen
│       │   ├── modeling_dalle2.py
│       │   ├── README.md
│       │   └── run_dalle2.py
│       ├── __init__.py
│       └── latent_diffusion
│           ├── modeling_latent_diffusion.py
│           ├── README.md
│           └── run_latent_diffusion.py
├── pyproject.toml
├── LICENSE
├── Makefile
├── README.md
├── pyproject.toml
├── setup.cfg
├── setup.py
├── src
│   └── diffusers
│       ├── configuration_utils.py
│       ├── __init__.py
│       ├── modeling_utils.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── unet_glide.py
│       │   └── unet.py
│       ├── pipeline_utils.py
│       └── schedulers
│           ├── gaussian_ddpm.py
│           ├── __init__.py
│   └── diffusers
│       ├── __init__.py
│       ├── configuration_utils.py
│       ├── dependency_versions_check.py
│       ├── dependency_versions_table.py
│       ├── dynamic_modules_utils.py
│       ├── modeling_utils.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── unet.py
│       │   ├── unet_glide.py
│       │   └── unet_ldm.py
│       ├── pipeline_utils.py
│       ├── pipelines
│       │   ├── __init__.py
│       │   ├── configuration_ldmbert.py
│       │   ├── conversion_glide.py
│       │   ├── modeling_vae.py
│       │   ├── pipeline_bddm.py
│       │   ├── pipeline_ddim.py
│       │   ├── pipeline_ddpm.py
│       │   ├── pipeline_glide.py
│       │   └── pipeline_latent_diffusion.py
│       ├── schedulers
│       │   ├── __init__.py
│       │   ├── classifier_free_guidance.py
│       │   ├── scheduling_ddim.py
│       │   ├── scheduling_ddpm.py
│       │   ├── scheduling_plms.py
│       │   └── scheduling_utils.py
│       ├── testing_utils.py
│       └── utils
│           ├── __init__.py
│           └── logging.py
├── tests
│   └── test_modeling_utils.py
│   ├── __init__.py
│   ├── test_modeling_utils.py
│   └── test_scheduler.py
└── utils
    ├── check_config_docstrings.py
    ├── check_copies.py
    ├── check_dummies.py
    ├── check_inits.py
    ├── check_repo.py
    ├── check_table.py
    └── check_tf_ops.py
```

View File

@@ -9,6 +9,6 @@ from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler

View File

@@ -225,11 +225,11 @@ class ConfigMixin:
text = reader.read()
return json.loads(text)
def __eq__(self, other):
return self.__dict__ == other.__dict__
# def __eq__(self, other):
# return self.__dict__ == other.__dict__
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
# def __repr__(self):
# return f"{self.__class__.__name__} {self.to_json_string()}"
@property
def config(self) -> Dict[str, Any]:

View File

@@ -3,3 +3,4 @@ from .pipeline_ddpm import DDPM
from .pipeline_pndm import PNDM
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_bddm import BDDM

View File

@@ -97,7 +97,9 @@ superres_model = GLIDESuperResUNetModel(
superres_model.load_state_dict(ups_state_dict, strict=False)
upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02)
upscale_scheduler = DDIMScheduler(
timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
)
glide = GLIDE(
text_unet=text2im_model,

View File

@@ -13,11 +13,18 @@
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline
def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
"""
@@ -41,8 +48,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
_embed = np.log(10000) / (half_dim - 1)
_embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
_embed = diffusion_steps * _embed
diffusion_step_embed = torch.cat((torch.sin(_embed),
torch.cos(_embed)), 1)
diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)
return diffusion_step_embed
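For intuition, the half-sine/half-cosine embedding computed above can be reproduced in a few lines. A minimal standalone sketch (plain PyTorch on CPU, with a hypothetical batch of one diffusion step; the helper name is illustrative):

```python
import math

import torch


def step_embedding(diffusion_steps, embed_dim=128):
    # Same construction as calc_diffusion_step_embedding above: geometric frequencies,
    # then sin/cos halves concatenated along the feature dimension.
    half_dim = embed_dim // 2
    freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000) / (half_dim - 1)))
    args = diffusion_steps * freqs  # (batch, 1) * (half_dim,) -> (batch, half_dim)
    return torch.cat((torch.sin(args), torch.cos(args)), dim=1)


emb = step_embedding(torch.tensor([[10.0]]))
print(emb.shape)  # torch.Size([1, 128])
```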
@@ -62,8 +68,7 @@ class Conv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
super().__init__()
self.padding = dilation * (kernel_size - 1) // 2
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
dilation=dilation, padding=self.padding)
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
self.conv = nn.utils.weight_norm(self.conv)
nn.init.kaiming_normal_(self.conv.weight)
@@ -89,8 +94,7 @@ class ZeroConv1d(nn.Module):
# every residual block (named residual layer in paper)
# contains one noncausal dilated conv
class ResidualBlock(nn.Module):
def __init__(self, res_channels, skip_channels, dilation,
diffusion_step_embed_dim_out):
def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out):
super().__init__()
self.res_channels = res_channels
@@ -98,15 +102,12 @@ class ResidualBlock(nn.Module):
self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)
# Dilated conv layer
self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
kernel_size=3, dilation=dilation)
self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)
# Add mel spectrogram upsampler and conditioner conv1x1 layer
self.upsample_conv2d = nn.ModuleList()
for s in [16, 16]:
conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
padding=(1, s // 2),
stride=(1, s))
conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
conv_trans2d = nn.utils.weight_norm(conv_trans2d)
nn.init.kaiming_normal_(conv_trans2d.weight)
self.upsample_conv2d.append(conv_trans2d)
@@ -152,7 +153,7 @@ class ResidualBlock(nn.Module):
h += mel_spec
# Gated-tanh nonlinearity
out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels :, :])
# Residual and skip outputs
res = self.res_conv(out)
@@ -164,10 +165,16 @@ class ResidualBlock(nn.Module):
class ResidualGroup(nn.Module):
def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out):
def __init__(
self,
res_channels,
skip_channels,
num_res_layers,
dilation_cycle,
diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out,
):
super().__init__()
self.num_res_layers = num_res_layers
self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
@@ -180,16 +187,19 @@ class ResidualGroup(nn.Module):
self.residual_blocks = nn.ModuleList()
for n in range(self.num_res_layers):
self.residual_blocks.append(
ResidualBlock(res_channels, skip_channels,
dilation=2 ** (n % dilation_cycle),
diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))
ResidualBlock(
res_channels,
skip_channels,
dilation=2 ** (n % dilation_cycle),
diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
)
)
def forward(self, input_data):
x, mel_spectrogram, diffusion_steps = input_data
# Embed diffusion step t
diffusion_step_embed = calc_diffusion_step_embedding(
diffusion_steps, self.diffusion_step_embed_dim_in)
diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in)
diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
@@ -206,27 +216,52 @@ class ResidualGroup(nn.Module):
return skip * math.sqrt(1.0 / self.num_res_layers)
class DiffWave(nn.Module):
def __init__(self, in_channels, res_channels, skip_channels, out_channels,
num_res_layers, dilation_cycle,
diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out):
class DiffWave(ModelMixin, ConfigMixin):
def __init__(
self,
in_channels=1,
res_channels=128,
skip_channels=128,
out_channels=1,
num_res_layers=30,
dilation_cycle=10,
diffusion_step_embed_dim_in=128,
diffusion_step_embed_dim_mid=512,
diffusion_step_embed_dim_out=512,
):
super().__init__()
# register all init arguments with self.register
self.register(
in_channels=in_channels,
res_channels=res_channels,
skip_channels=skip_channels,
out_channels=out_channels,
num_res_layers=num_res_layers,
dilation_cycle=dilation_cycle,
diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
)
# Initial conv1x1 with relu
self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
# All residual layers
self.residual_layer = ResidualGroup(res_channels,
skip_channels,
num_res_layers,
dilation_cycle,
diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out)
self.residual_layer = ResidualGroup(
res_channels,
skip_channels,
num_res_layers,
dilation_cycle,
diffusion_step_embed_dim_in,
diffusion_step_embed_dim_mid,
diffusion_step_embed_dim_out,
)
# Final conv1x1 -> relu -> zeroconv1x1
self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
nn.ReLU(inplace=False), ZeroConv1d(skip_channels, out_channels))
self.final_conv = nn.Sequential(
Conv(skip_channels, skip_channels, kernel_size=1),
nn.ReLU(inplace=False),
ZeroConv1d(skip_channels, out_channels),
)
def forward(self, input_data):
audio, mel_spectrogram, diffusion_steps = input_data
@@ -234,3 +269,45 @@ class DiffWave(nn.Module):
x = self.init_conv(x).clone()
x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
return self.final_conv(x)
class BDDM(DiffusionPipeline):
    def __init__(self, diffwave, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(self, mel_spectrogram, generator, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.diffwave.to(torch_device)
        mel_spectrogram = mel_spectrogram.to(torch_device)
        audio_length = mel_spectrogram.size(-1) * 256
        audio_size = (1, 1, audio_length)

        # Sample gaussian noise to begin loop
        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

        timestep_values = self.noise_scheduler.timestep_values
        num_prediction_steps = len(self.noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
            # 1. predict noise residual
            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
            residual = self.diffwave((audio, mel_spectrogram, ts))

            # 2. predict previous mean of audio x_t-1
            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)

            # 3. optionally sample variance
            variance = 0
            if t > 0:
                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
                variance = self.noise_scheduler.get_variance(t).sqrt() * noise

            # 4. set current audio to prev_audio: x_t -> x_t-1
            audio = pred_prev_audio + variance

        return audio
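The pipeline above derives the waveform length from the conditioning spectrogram: 256 audio samples per mel frame, the product of the two stride-16 transposed convolutions in the DiffWave upsampler. A quick sanity-check sketch with a hypothetical 80 x 400 mel spectrogram:

```python
import torch

mel_spectrogram = torch.zeros(1, 80, 400)      # hypothetical input: 80 mel bins, 400 frames
audio_length = mel_spectrogram.size(-1) * 256  # matches the pipeline above
print(audio_length)  # 102400 samples, roughly 4.6 s at 22050 Hz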

View File

@@ -28,13 +28,7 @@ from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline
@@ -872,31 +866,31 @@ class GLIDE(DiffusionPipeline):
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
(
batch_size,
self.upscale_unet.in_channels // 2,
self.upscale_unet.resolution,
self.upscale_unet.resolution,
),
generator=generator,
)
image = image.to(torch_device)
image = image.to(torch_device) * upsample_temp
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper for an in-detail understanding
num_trained_timesteps = self.upscale_noise_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
# adapt the beta schedule to the number of steps
# self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)
# Notation (<variable name> -> <name in paper>)
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointing to x_t"
# - pred_prev_image -> "x_t-1"
for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
# 1. predict noise residual
with torch.no_grad():
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
model_output = self.upscale_unet(image, time_input, low_res)
noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
# 2. predict previous mean of image x_t-1
pred_prev_image = self.upscale_noise_scheduler.step(
noise_residual, image, t, num_inference_steps_upscale, eta
noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
)
# 3. optionally sample variance
@@ -910,6 +904,6 @@ class GLIDE(DiffusionPipeline):
# 4. set current image to prev_image: x_t -> x_t-1
image = pred_prev_image + variance
image = image.permute(0, 2, 3, 1)
image = image.clamp(-1, 1).permute(0, 2, 3, 1)
return image

View File

@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
clip_predicted_image=True,
tensor_format="np",
):
@@ -37,6 +39,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule=beta_schedule,
)
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image
if beta_schedule == "linear":
@@ -69,14 +72,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
#
# self.register_buffer("log_variance", log_variance.to(torch.float32))
def rescale_betas(self, num_timesteps):
if self.beta_schedule == "linear":
scale = self.timesteps / num_timesteps
self.betas = linear_beta_schedule(
num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
# def rescale_betas(self, num_timesteps):
# # GLIDE scaling
# if self.beta_schedule == "linear":
# scale = self.timesteps / num_timesteps
# self.betas = linear_beta_schedule(
# num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
# )
# self.alphas = 1.0 - self.betas
# self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
def get_alpha(self, time_step):
return self.alphas[time_step]
@@ -107,7 +111,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance
def step(self, residual, image, t, num_inference_steps, eta):
def step(self, residual, image, t, num_inference_steps, eta, use_clipped_residual=False):
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper for an in-detail understanding
@@ -141,6 +145,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
variance = self.get_variance(t, num_inference_steps)
std_dev_t = eta * variance ** (0.5)
if use_clipped_residual:
# the residual is always re-derived from the clipped x_0 in GLIDE
residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
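The `use_clipped_residual` branch follows from the forward relation `image = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * residual`: solving for the residual with the clipped `x_0` keeps the DDIM update consistent with clipping. A small numeric sketch of that inversion (NumPy, made-up values):

```python
import numpy as np

alpha_prod_t = 0.9
pred_original_image = np.clip(np.array([1.3, -0.4]), -1, 1)  # clipped x_0 prediction
noise = np.array([0.5, -0.1])

# forward relation: x_t from x_0 and the noise
image = alpha_prod_t ** 0.5 * pred_original_image + (1 - alpha_prod_t) ** 0.5 * noise

# inverting it, as in the `use_clipped_residual` branch above
residual = (image - alpha_prod_t ** 0.5 * pred_original_image) / (1 - alpha_prod_t) ** 0.5
print(np.allclose(residual, noise))  # True
```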

View File

@@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
variance_type="fixed_small",
clip_predicted_image=True,
tensor_format="np",
@@ -36,14 +38,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_start=beta_start,
beta_end=beta_end,
beta_schedule=beta_schedule,
trained_betas=trained_betas,
timestep_values=timestep_values,
variance_type=variance_type,
clip_predicted_image=clip_predicted_image,
)
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image
self.variance_type = variance_type
if beta_schedule == "linear":
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule
@@ -56,6 +63,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.alphas = 1.0 - self.betas
self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1 - self.alphas_cumprod)
self.one = np.array(1.0)
self.set_format(tensor_format=tensor_format)
@@ -131,5 +140,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_image
def forward_step(self, original_image, noise, t):
noisy_image = self.sqrt_alphas_cumprod[t] * original_image + self.sqrt_one_minus_alphas_cumprod[t] * noise
return noisy_image
def __len__(self):
return self.timesteps
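The new `trained_betas` argument lets a scheduler restore a fixed beta schedule (for example one saved alongside a checkpoint) instead of re-deriving it from `beta_start`/`beta_end`. A minimal sketch with hypothetical values:

```python
import numpy as np

from diffusers import DDPMScheduler

# hypothetical fixed schedule, e.g. read back from a trained model's config
betas = np.linspace(1e-4, 0.02, 1000)
scheduler = DDPMScheduler(timesteps=1000, trained_betas=betas.tolist(), tensor_format="np")
print(len(scheduler))  # 1000
```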

View File

@@ -0,0 +1,116 @@
import random
import numpy as np
import torch
import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
set_seed(0)
accelerator = Accelerator(mixed_precision="fp16")
model = UNetModel(ch=128, ch_mult=(1, 2, 4, 8), resolution=64)
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
num_epochs = 100
batch_size = 8
gradient_accumulation_steps = 8
augmentations = Compose(
    [
        Resize(64),
        CenterCrop(64),
        RandomHorizontalFlip(),
        ToTensor(),
        Lambda(lambda x: x * 2 - 1),
    ]
)
dataset = load_dataset("huggan/pokemon", split="train")
def transforms(examples):
    images = [augmentations(image.convert("RGB")) for image in examples["image"]]
    return {"input": images}
dataset = dataset.shuffle(seed=0)
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)
for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(total=len(train_dataloader), unit="ba")
    pbar.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["input"]
        noisy_images = torch.empty_like(clean_images)
        bsz = clean_images.shape[0]
        timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
        for idx in range(bsz):
            noise = torch.randn_like(clean_images[0]).to(clean_images.device)
            noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])

        if step % gradient_accumulation_steps == 0:
            with accelerator.no_sync(model):
                output = model(noisy_images, timesteps)
                loss = F.l1_loss(output, clean_images)
                accelerator.backward(loss)
        else:
            output = model(noisy_images, timesteps)
            loss = F.l1_loss(output, clean_images)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        pbar.update(1)
        pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
    optimizer.step()

    # eval
    model.eval()
    with torch.no_grad():
        pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
        generator = torch.Generator()
        generator = generator.manual_seed(0)
        # run pipeline in inference (sample random noise and denoise)
        image = pipeline(generator=generator)

    # process image to PIL
    image_processed = image.cpu().permute(0, 2, 3, 1)
    image_processed = (image_processed + 1.0) * 127.5
    image_processed = image_processed.type(torch.uint8).numpy()
    image_pil = PIL.Image.fromarray(image_processed[0])

    # save image
    pipeline.save_pretrained("./poke-ddpm")
    image_pil.save(f"./poke-ddpm/test_{epoch}.png")