mirror of https://github.com/huggingface/diffusers.git

merge

README.md (131 lines changed)
@@ -164,7 +164,7 @@ image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save("test.png")
```

**Text to Image generation with Latent Diffusion**
#### **Text to Image generation with Latent Diffusion**

```python
from diffusers import DiffusionPipeline
@@ -184,59 +184,98 @@ image_pil = PIL.Image.fromarray(image_processed[0])

# save image
image_pil.save("test.png")
```

#### **Text to speech with BDDM**

_Follow the instructions [here](https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/) to load the tacotron2 model._

```python
import torch
from diffusers import BDDM, DiffusionPipeline

torch_device = "cuda"

# load the BDDM pipeline
bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder")

# load tacotron2 to get the mel spectrograms
tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')
tacotron2 = tacotron2.to(torch_device).eval()

text = "Hello world, I missed you so much."

utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')
sequences, lengths = utils.prepare_input_sequence([text])

# generate mel spectrograms using text
with torch.no_grad():
    mel_spec, _, _ = tacotron2.infer(sequences, lengths)

# generate the speech by passing mel spectrograms to BDDM pipeline
generator = torch.manual_seed(0)
audio = bddm(mel_spec, generator, torch_device)

# save generated audio
from scipy.io.wavfile import write as wavwrite
sampling_rate = 22050
wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
```
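To listen to the result in a notebook, something like the following should work (a minimal sketch, assuming IPython is available; `generated_audio.wav`, `audio`, and `sampling_rate` are the names from the snippet above):

```python
from IPython.display import Audio

# play the file written by wavwrite above
Audio("generated_audio.wav")
# or pass the waveform directly: Audio(audio.squeeze().cpu().numpy(), rate=sampling_rate)
```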
## Library structure:

```
├── models
│ ├── audio
│ │ └── fastdiff
│ │ ├── modeling_fastdiff.py
│ │ ├── README.md
│ │ └── run_fastdiff.py
│ ├── __init__.py
│ └── vision
│ ├── dalle2
│ │ ├── modeling_dalle2.py
│ │ ├── README.md
│ │ └── run_dalle2.py
│ ├── ddpm
│ │ ├── example.py
│ │ ├── modeling_ddpm.py
│ │ ├── README.md
│ │ └── run_ddpm.py
│ ├── glide
│ │ ├── modeling_glide.py
│ │ ├── modeling_vqvae.py.py
│ │ ├── README.md
│ │ └── run_glide.py
│ ├── imagen
│ │ ├── modeling_dalle2.py
│ │ ├── README.md
│ │ └── run_dalle2.py
│ ├── __init__.py
│ └── latent_diffusion
│ ├── modeling_latent_diffusion.py
│ ├── README.md
│ └── run_latent_diffusion.py
├── pyproject.toml
├── LICENSE
├── Makefile
├── README.md
├── pyproject.toml
├── setup.cfg
├── setup.py
├── src
│ └── diffusers
│ ├── configuration_utils.py
│ ├── __init__.py
│ ├── modeling_utils.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── unet_glide.py
│ │ └── unet.py
│ ├── pipeline_utils.py
│ └── schedulers
│ ├── gaussian_ddpm.py
│ ├── __init__.py
│ ├── diffusers
│ ├── __init__.py
│ ├── configuration_utils.py
│ ├── dependency_versions_check.py
│ ├── dependency_versions_table.py
│ ├── dynamic_modules_utils.py
│ ├── modeling_utils.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── unet.py
│ │ ├── unet_glide.py
│ │ └── unet_ldm.py
│ ├── pipeline_utils.py
│ ├── pipelines
│ │ ├── __init__.py
│ │ ├── configuration_ldmbert.py
│ │ ├── conversion_glide.py
│ │ ├── modeling_vae.py
│ │ ├── pipeline_bddm.py
│ │ ├── pipeline_ddim.py
│ │ ├── pipeline_ddpm.py
│ │ ├── pipeline_glide.py
│ │ └── pipeline_latent_diffusion.py
│ ├── schedulers
│ │ ├── __init__.py
│ │ ├── classifier_free_guidance.py
│ │ ├── scheduling_ddim.py
│ │ ├── scheduling_ddpm.py
│ │ ├── scheduling_plms.py
│ │ └── scheduling_utils.py
│ ├── testing_utils.py
│ └── utils
│ ├── __init__.py
│ └── logging.py
├── tests
│ └── test_modeling_utils.py
│ ├── __init__.py
│ ├── test_modeling_utils.py
│ └── test_scheduler.py
└── utils
├── check_config_docstrings.py
├── check_copies.py
├── check_dummies.py
├── check_inits.py
├── check_repo.py
├── check_table.py
└── check_tf_ops.py
```
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
@@ -225,11 +225,11 @@ class ConfigMixin:
            text = reader.read()
        return json.loads(text)

    def __eq__(self, other):
        return self.__dict__ == other.__dict__
    # def __eq__(self, other):
    #     return self.__dict__ == other.__dict__

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"
    # def __repr__(self):
    #     return f"{self.__class__.__name__} {self.to_json_string()}"

    @property
    def config(self) -> Dict[str, Any]:
@@ -3,3 +3,4 @@ from .pipeline_ddpm import DDPM
from .pipeline_pndm import PNDM
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_bddm import BDDM
@@ -97,7 +97,9 @@ superres_model = GLIDESuperResUNetModel(

superres_model.load_state_dict(ups_state_dict, strict=False)

upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02)
upscale_scheduler = DDIMScheduler(
    timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
)

glide = GLIDE(
    text_unet=text2im_model,
@@ -13,11 +13,18 @@


import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import tqdm

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline


def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
    """
@@ -41,8 +48,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
    _embed = np.log(10000) / (half_dim - 1)
    _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
    _embed = diffusion_steps * _embed
    diffusion_step_embed = torch.cat((torch.sin(_embed),
                                      torch.cos(_embed)), 1)
    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)
    return diffusion_step_embed

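For reference, here is a minimal standalone sketch of the sinusoidal step embedding that `calc_diffusion_step_embedding` computes above (illustrative only; it keeps everything on CPU and assumes an even embedding dimension):

```python
import numpy as np
import torch


def sinusoidal_step_embedding(diffusion_steps, embed_dim):
    # diffusion_steps: (batch, 1) float tensor of step indices; embed_dim must be even
    half_dim = embed_dim // 2
    scale = np.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim) * -scale)            # (half_dim,)
    args = diffusion_steps * freqs                                 # (batch, half_dim)
    return torch.cat((torch.sin(args), torch.cos(args)), dim=1)   # (batch, embed_dim)


print(sinusoidal_step_embedding(torch.tensor([[10.0]]), 128).shape)  # torch.Size([1, 128])
```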
@@ -62,8 +68,7 @@ class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()
        self.padding = dilation * (kernel_size - 1) // 2
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
                              dilation=dilation, padding=self.padding)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
        self.conv = nn.utils.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)

@@ -89,8 +94,7 @@ class ZeroConv1d(nn.Module):
# every residual block (named residual layer in paper)
# contains one noncausal dilated conv
class ResidualBlock(nn.Module):
    def __init__(self, res_channels, skip_channels, dilation,
                 diffusion_step_embed_dim_out):
    def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out):
        super().__init__()
        self.res_channels = res_channels
@@ -98,15 +102,12 @@ class ResidualBlock(nn.Module):
        self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)

        # Dilated conv layer
        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
                                       kernel_size=3, dilation=dilation)
        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)

        # Add mel spectrogram upsampler and conditioner conv1x1 layer
        self.upsample_conv2d = nn.ModuleList()
        for s in [16, 16]:
            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
                                              padding=(1, s // 2),
                                              stride=(1, s))
            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
            conv_trans2d = nn.utils.weight_norm(conv_trans2d)
            nn.init.kaiming_normal_(conv_trans2d.weight)
            self.upsample_conv2d.append(conv_trans2d)
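Note that the two transposed convolutions with stride (1, 16) upsample the conditioning mel spectrogram along time by a factor of 16 × 16 = 256, which is why `BDDM.__call__` further down sizes the output waveform as `mel_spectrogram.size(-1) * 256`.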
@@ -152,7 +153,7 @@ class ResidualBlock(nn.Module):
        h += mel_spec

        # Gated-tanh nonlinearity
        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
        out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels :, :])

        # Residual and skip outputs
        res = self.res_conv(out)
@@ -164,10 +165,16 @@ class ResidualBlock(nn.Module):


class ResidualGroup(nn.Module):
    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
                 diffusion_step_embed_dim_in,
                 diffusion_step_embed_dim_mid,
                 diffusion_step_embed_dim_out):
    def __init__(
        self,
        res_channels,
        skip_channels,
        num_res_layers,
        dilation_cycle,
        diffusion_step_embed_dim_in,
        diffusion_step_embed_dim_mid,
        diffusion_step_embed_dim_out,
    ):
        super().__init__()
        self.num_res_layers = num_res_layers
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
@@ -180,16 +187,19 @@ class ResidualGroup(nn.Module):
        self.residual_blocks = nn.ModuleList()
        for n in range(self.num_res_layers):
            self.residual_blocks.append(
                ResidualBlock(res_channels, skip_channels,
                              dilation=2 ** (n % dilation_cycle),
                              diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))
                ResidualBlock(
                    res_channels,
                    skip_channels,
                    dilation=2 ** (n % dilation_cycle),
                    diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
                )
            )

    def forward(self, input_data):
        x, mel_spectrogram, diffusion_steps = input_data

        # Embed diffusion step t
        diffusion_step_embed = calc_diffusion_step_embedding(
            diffusion_steps, self.diffusion_step_embed_dim_in)
        diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in)
        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
@@ -206,27 +216,52 @@ class ResidualGroup(nn.Module):
        return skip * math.sqrt(1.0 / self.num_res_layers)


class DiffWave(nn.Module):
    def __init__(self, in_channels, res_channels, skip_channels, out_channels,
                 num_res_layers, dilation_cycle,
                 diffusion_step_embed_dim_in,
                 diffusion_step_embed_dim_mid,
                 diffusion_step_embed_dim_out):
class DiffWave(ModelMixin, ConfigMixin):
    def __init__(
        self,
        in_channels=1,
        res_channels=128,
        skip_channels=128,
        out_channels=1,
        num_res_layers=30,
        dilation_cycle=10,
        diffusion_step_embed_dim_in=128,
        diffusion_step_embed_dim_mid=512,
        diffusion_step_embed_dim_out=512,
    ):
        super().__init__()

        # register all init arguments with self.register
        self.register(
            in_channels=in_channels,
            res_channels=res_channels,
            skip_channels=skip_channels,
            out_channels=out_channels,
            num_res_layers=num_res_layers,
            dilation_cycle=dilation_cycle,
            diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
            diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
            diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
        )

        # Initial conv1x1 with relu
        self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
        # All residual layers
        self.residual_layer = ResidualGroup(res_channels,
                                            skip_channels,
                                            num_res_layers,
                                            dilation_cycle,
                                            diffusion_step_embed_dim_in,
                                            diffusion_step_embed_dim_mid,
                                            diffusion_step_embed_dim_out)
        self.residual_layer = ResidualGroup(
            res_channels,
            skip_channels,
            num_res_layers,
            dilation_cycle,
            diffusion_step_embed_dim_in,
            diffusion_step_embed_dim_mid,
            diffusion_step_embed_dim_out,
        )
        # Final conv1x1 -> relu -> zeroconv1x1
        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
                                        nn.ReLU(inplace=False), ZeroConv1d(skip_channels, out_channels))
        self.final_conv = nn.Sequential(
            Conv(skip_channels, skip_channels, kernel_size=1),
            nn.ReLU(inplace=False),
            ZeroConv1d(skip_channels, out_channels),
        )

    def forward(self, input_data):
        audio, mel_spectrogram, diffusion_steps = input_data
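Registering the init arguments via `self.register` is what makes `DiffWave` usable as a `ModelMixin`/`ConfigMixin` model: the arguments are presumably serialized to the model config, which is what lets `DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder")` in the README example re-instantiate the model with the right hyperparameters.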
@@ -234,3 +269,45 @@ class DiffWave(nn.Module):
        x = self.init_conv(x).clone()
        x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
        return self.final_conv(x)


class BDDM(DiffusionPipeline):
    def __init__(self, diffwave, noise_scheduler):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(self, mel_spectrogram, generator, torch_device=None):
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.diffwave.to(torch_device)

        mel_spectrogram = mel_spectrogram.to(torch_device)
        audio_length = mel_spectrogram.size(-1) * 256
        audio_size = (1, 1, audio_length)

        # Sample gaussian noise to begin loop
        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

        timestep_values = self.noise_scheduler.timestep_values
        num_prediction_steps = len(self.noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
            # 1. predict noise residual
            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
            residual = self.diffwave((audio, mel_spectrogram, ts))

            # 2. predict previous mean of audio x_t-1
            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)

            # 3. optionally sample variance
            variance = 0
            if t > 0:
                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
                variance = self.noise_scheduler.get_variance(t).sqrt() * noise

            # 4. set current audio to prev_audio: x_t -> x_t-1
            audio = pred_prev_audio + variance

        return audio
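The loop above is standard ancestral sampling for a diffusion model: step 2 asks the scheduler for the predicted previous-step mean and steps 3–4 add the sampled variance, i.e. (in the usual DDPM notation, and assuming the scheduler's `step` returns the posterior mean $\mu_\theta$)

$$x_{t-1} = \mu_\theta(x_t, t) + \sigma_t z, \qquad \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right), \qquad z \sim \mathcal{N}(0, I).$$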
@@ -28,13 +28,7 @@ from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline
@@ -872,31 +866,31 @@ class GLIDE(DiffusionPipeline):

        # Sample gaussian noise to begin loop
        image = torch.randn(
            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
            (
                batch_size,
                self.upscale_unet.in_channels // 2,
                self.upscale_unet.resolution,
                self.upscale_unet.resolution,
            ),
            generator=generator,
        )
        image = image.to(torch_device)
        image = image.to(torch_device) * upsample_temp

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper for an in-detail understanding
        num_trained_timesteps = self.upscale_noise_scheduler.timesteps
        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
        # adapt the beta schedule to the number of steps
        # self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)

        # Notation (<variable name> -> <name in paper>)
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_image -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_image_direction -> "direction pointing to x_t"
        # - pred_prev_image -> "x_t-1"
        for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
            # 1. predict noise residual
            with torch.no_grad():
                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
                time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                model_output = self.upscale_unet(image, time_input, low_res)
                noise_residual, pred_variance = torch.split(model_output, 3, dim=1)

            # 2. predict previous mean of image x_t-1
            pred_prev_image = self.upscale_noise_scheduler.step(
                noise_residual, image, t, num_inference_steps_upscale, eta
                noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
            )

            # 3. optionally sample variance
@@ -910,6 +904,6 @@ class GLIDE(DiffusionPipeline):
            # 4. set current image to prev_image: x_t -> x_t-1
            image = pred_prev_image + variance

        image = image.permute(0, 2, 3, 1)
        image = image.clamp(-1, 1).permute(0, 2, 3, 1)

        return image
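For reference, the notation block above maps onto the DDIM update, formula (12) of the paper, which steps 2–4 of the loop assemble:

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\underbrace{\left(\frac{x_t - \sqrt{1 - \bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}\right)}_{\text{pred\_original\_image}} + \underbrace{\sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2}\;\epsilon_\theta(x_t, t)}_{\text{pred\_image\_direction}} + \underbrace{\sigma_t z}_{\text{variance}}$$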
@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        clip_predicted_image=True,
        tensor_format="np",
    ):
@@ -37,6 +39,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
            beta_schedule=beta_schedule,
        )
        self.timesteps = int(timesteps)
        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
        self.clip_image = clip_predicted_image

        if beta_schedule == "linear":
@@ -69,14 +72,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        #
        # self.register_buffer("log_variance", log_variance.to(torch.float32))

    def rescale_betas(self, num_timesteps):
        if self.beta_schedule == "linear":
            scale = self.timesteps / num_timesteps
            self.betas = linear_beta_schedule(
                num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
            )
            self.alphas = 1.0 - self.betas
            self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
    # def rescale_betas(self, num_timesteps):
    #     # GLIDE scaling
    #     if self.beta_schedule == "linear":
    #         scale = self.timesteps / num_timesteps
    #         self.betas = linear_beta_schedule(
    #             num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
    #         )
    #         self.alphas = 1.0 - self.betas
    #         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)

    def get_alpha(self, time_step):
        return self.alphas[time_step]
@@ -107,7 +111,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):

        return variance

    def step(self, residual, image, t, num_inference_steps, eta):
    def step(self, residual, image, t, num_inference_steps, eta, use_clipped_residual=False):
        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read the DDIM paper for an in-detail understanding

@@ -141,6 +145,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
        variance = self.get_variance(t, num_inference_steps)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_residual:
            # the residual is always re-derived from the clipped x_0 in GLIDE
            residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
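The `use_clipped_residual` branch simply inverts the forward process to re-derive the noise estimate from the (clipped) $x_0$ prediction: from $x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon$ it computes

$$\epsilon = \frac{x_t - \sqrt{\bar{\alpha}_t}\,x_0}{\sqrt{1 - \bar{\alpha}_t}},$$

which matches the `residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5)` line above, assuming `beta_prod_t` is defined as `1 - alpha_prod_t` earlier in `step`.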
@@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        trained_betas=None,
        timestep_values=None,
        variance_type="fixed_small",
        clip_predicted_image=True,
        tensor_format="np",
@@ -36,14 +38,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
            trained_betas=trained_betas,
            timestep_values=timestep_values,
            variance_type=variance_type,
            clip_predicted_image=clip_predicted_image,
        )
        self.timesteps = int(timesteps)
        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
        self.clip_image = clip_predicted_image
        self.variance_type = variance_type

        if beta_schedule == "linear":
        if trained_betas is not None:
            self.betas = np.asarray(trained_betas)
        elif beta_schedule == "linear":
            self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
        elif beta_schedule == "squaredcos_cap_v2":
            # GLIDE cosine schedule
@@ -56,6 +63,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
        self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = np.sqrt(1 - self.alphas_cumprod)
        self.one = np.array(1.0)

        self.set_format(tensor_format=tensor_format)
@@ -131,5 +140,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

        return pred_prev_image

    def forward_step(self, original_image, noise, t):
        noisy_image = self.sqrt_alphas_cumprod[t] * original_image + self.sqrt_one_minus_alphas_cumprod[t] * noise
        return noisy_image

    def __len__(self):
        return self.timesteps
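The new `forward_step` is the closed-form forward (noising) process that the training script below uses to build `noisy_images`: in the usual notation it samples $x_t$ from $q(x_t \mid x_0)$ via

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1 - \bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),$$

with $\sqrt{\bar{\alpha}_t}$ and $\sqrt{1 - \bar{\alpha}_t}$ precomputed above as `sqrt_alphas_cumprod` and `sqrt_one_minus_alphas_cumprod`.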
src/diffusers/trainers/training_ddpm.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import random

import numpy as np
import torch
import torch.nn.functional as F

import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup


def set_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


set_seed(0)

accelerator = Accelerator(mixed_precision="fp16")

model = UNetModel(ch=128, ch_mult=(1, 2, 4, 8), resolution=64)
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

num_epochs = 100
batch_size = 8
gradient_accumulation_steps = 8

augmentations = Compose(
    [
        Resize(64),
        CenterCrop(64),
        RandomHorizontalFlip(),
        ToTensor(),
        Lambda(lambda x: x * 2 - 1),
    ]
)
dataset = load_dataset("huggan/pokemon", split="train")


def transforms(examples):
    images = [augmentations(image.convert("RGB")) for image in examples["image"]]
    return {"input": images}


dataset = dataset.shuffle(seed=0)
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=1000,
    num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

for epoch in range(num_epochs):
    model.train()
    pbar = tqdm(total=len(train_dataloader), unit="ba")
    pbar.set_description(f"Epoch {epoch}")
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["input"]
        noisy_images = torch.empty_like(clean_images)
        bsz = clean_images.shape[0]

        timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
        for idx in range(bsz):
            noise = torch.randn_like(clean_images[0]).to(clean_images.device)
            noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])

        if step % gradient_accumulation_steps == 0:
            with accelerator.no_sync(model):
                output = model(noisy_images, timesteps)
                loss = F.l1_loss(output, clean_images)
                accelerator.backward(loss)
        else:
            output = model(noisy_images, timesteps)
            loss = F.l1_loss(output, clean_images)
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        pbar.update(1)
        pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])

    optimizer.step()

    # eval
    model.eval()
    with torch.no_grad():
        pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
        generator = torch.Generator()
        generator = generator.manual_seed(0)
        # run pipeline in inference (sample random noise and denoise)
        image = pipeline(generator=generator)

        # process image to PIL
        image_processed = image.cpu().permute(0, 2, 3, 1)
        image_processed = (image_processed + 1.0) * 127.5
        image_processed = image_processed.type(torch.uint8).numpy()
        image_pil = PIL.Image.fromarray(image_processed[0])

        # save image
        pipeline.save_pretrained("./poke-ddpm")
        image_pil.save(f"./poke-ddpm/test_{epoch}.png")
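After training, the saved directory can presumably be reloaded for inference the same way the README examples load pretrained pipelines (a minimal sketch, assuming `DDPM` inherits `from_pretrained` from `DiffusionPipeline`; `./poke-ddpm` is the path used by the script above):

```python
import torch
from diffusers import DDPM

# reload the pipeline saved above via pipeline.save_pretrained("./poke-ddpm")
pipeline = DDPM.from_pretrained("./poke-ddpm")

generator = torch.manual_seed(0)
image = pipeline(generator=generator)  # post-process to PIL exactly as in the training loop above
```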