diff --git a/README.md b/README.md
index 10d3655b21..7c0a2fe71f 100644
--- a/README.md
+++ b/README.md
@@ -164,7 +164,7 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 image_pil.save("test.png")
 ```
 
-**Text to Image generation with Latent Diffusion**
+#### **Text to Image generation with Latent Diffusion**
 
 ```python
 from diffusers import DiffusionPipeline
@@ -184,59 +184,98 @@ image_pil = PIL.Image.fromarray(image_processed[0])
 
 # save image
 image_pil.save("test.png")
+```
+
+#### **Text to speech with BDDM**
+
+_Follow the instructions [here](https://pytorch.org/hub/nvidia_deeplearningexamples_tacotron2/) to load the tacotron2 model._
+
+```python
+import torch
+from diffusers import BDDM, DiffusionPipeline
+
+torch_device = "cuda"
+
+# load the BDDM pipeline
+bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder")
+
+# load tacotron2 to get the mel spectrograms
+tacotron2 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tacotron2', model_math='fp16')
+tacotron2 = tacotron2.to(torch_device).eval()
+
+text = "Hello world, I missed you so much."
+
+utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_tts_utils')
+sequences, lengths = utils.prepare_input_sequence([text])
+
+# generate mel spectrograms from the text
+with torch.no_grad():
+    mel_spec, _, _ = tacotron2.infer(sequences, lengths)
+
+# generate the speech by passing the mel spectrograms to the BDDM pipeline
+generator = torch.manual_seed(0)
+audio = bddm(mel_spec, generator, torch_device)
+
+# save the generated audio
+from scipy.io.wavfile import write as wavwrite
+sampling_rate = 22050
+wavwrite("generated_audio.wav", sampling_rate, audio.squeeze().cpu().numpy())
 ```
 
 ## Library structure:
 
 ```
-├── models
-│   ├── audio
-│   │   └── fastdiff
-│   │       ├── modeling_fastdiff.py
-│   │       ├── README.md
-│   │       └── run_fastdiff.py
-│   ├── __init__.py
-│   └── vision
-│       ├── dalle2
-│       │   ├── modeling_dalle2.py
-│       │   ├── README.md
-│       │   └── run_dalle2.py
-│       ├── ddpm
-│       │   ├── example.py
-│       │   ├── modeling_ddpm.py
-│       │   ├── README.md
-│       │   └── run_ddpm.py
-│       ├── glide
-│       │   ├── modeling_glide.py
-│       │   ├── modeling_vqvae.py.py
-│       │   ├── README.md
-│       │   └── run_glide.py
-│       ├── imagen
-│       │   ├── modeling_dalle2.py
-│       │   ├── README.md
-│       │   └── run_dalle2.py
-│       ├── __init__.py
-│       └── latent_diffusion
-│           ├── modeling_latent_diffusion.py
-│           ├── README.md
-│           └── run_latent_diffusion.py
-├── pyproject.toml
+├── LICENSE
+├── Makefile
 ├── README.md
+├── pyproject.toml
 ├── setup.cfg
 ├── setup.py
 ├── src
-│   └── diffusers
-│       ├── configuration_utils.py
-│       ├── __init__.py
-│       ├── modeling_utils.py
-│       ├── models
-│       │   ├── __init__.py
-│       │   ├── unet_glide.py
-│       │   └── unet.py
-│       ├── pipeline_utils.py
-│       └── schedulers
-│           ├── gaussian_ddpm.py
-│           ├── __init__.py
+│   └── diffusers
+│       ├── __init__.py
+│       ├── configuration_utils.py
+│       ├── dependency_versions_check.py
+│       ├── dependency_versions_table.py
+│       ├── dynamic_modules_utils.py
+│       ├── modeling_utils.py
+│       ├── models
+│       │   ├── __init__.py
+│       │   ├── unet.py
+│       │   ├── unet_glide.py
+│       │   └── unet_ldm.py
+│       ├── pipeline_utils.py
+│       ├── pipelines
+│       │   ├── __init__.py
+│       │   ├── configuration_ldmbert.py
+│       │   ├── conversion_glide.py
+│       │   ├── modeling_vae.py
+│       │   ├── pipeline_bddm.py
+│       │   ├── pipeline_ddim.py
+│       │   ├── pipeline_ddpm.py
+│       │   ├── pipeline_glide.py
+│       │   └── pipeline_latent_diffusion.py
+│       ├── schedulers
+│       │   ├── __init__.py
+│       │   ├── classifier_free_guidance.py
+│       │   ├── scheduling_ddim.py
+│       │   ├── scheduling_ddpm.py
+│       │   ├── scheduling_plms.py
+│       │   └── scheduling_utils.py
+│       ├── testing_utils.py
+│       └── utils
+│           ├── __init__.py
+│           └── logging.py
 ├── tests
-│   └── test_modeling_utils.py
+│   ├── __init__.py
+│   ├── test_modeling_utils.py
+│   └── test_scheduler.py
+└── utils
+    ├── check_config_docstrings.py
+    ├── check_copies.py
+    ├── check_dummies.py
+    ├── check_inits.py
+    ├── check_repo.py
+    ├── check_table.py
+    └── check_tf_ops.py
 ```
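Note on the README example above: the new BDDM pipeline follows the same load/save contract as the image pipelines. A minimal sketch of the round trip, assuming only `DiffusionPipeline.from_pretrained` and `save_pretrained` as used elsewhere in this PR (the local path is illustrative):

```python
from diffusers import DiffusionPipeline

# Download the vocoder pipeline from the Hub (same checkpoint as in the README example).
bddm = DiffusionPipeline.from_pretrained("fusing/diffwave-vocoder")

# Persist it locally and load it back; from_pretrained accepts a local folder as well.
bddm.save_pretrained("./diffwave-vocoder-local")
bddm = DiffusionPipeline.from_pretrained("./diffwave-vocoder-local")
```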
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index a3b9f1021a..d46c6b0902 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM
+from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index ce0b7d0ea5..4436445334 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -225,11 +225,11 @@ class ConfigMixin:
             text = reader.read()
         return json.loads(text)
 
-    def __eq__(self, other):
-        return self.__dict__ == other.__dict__
+    # def __eq__(self, other):
+    #     return self.__dict__ == other.__dict__
 
-    def __repr__(self):
-        return f"{self.__class__.__name__} {self.to_json_string()}"
+    # def __repr__(self):
+    #     return f"{self.__class__.__name__} {self.to_json_string()}"
 
     @property
     def config(self) -> Dict[str, Any]:
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 0ea00d8763..e0d2bf2e30 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -3,3 +3,4 @@ from .pipeline_ddpm import DDPM
 from .pipeline_pndm import PNDM
 from .pipeline_glide import GLIDE
 from .pipeline_latent_diffusion import LatentDiffusion
+from .pipeline_bddm import BDDM
diff --git a/src/diffusers/pipelines/conversion_glide.py b/src/diffusers/pipelines/conversion_glide.py
index 499c071204..2d04580e76 100644
--- a/src/diffusers/pipelines/conversion_glide.py
+++ b/src/diffusers/pipelines/conversion_glide.py
@@ -97,7 +97,9 @@ superres_model = GLIDESuperResUNetModel(
 
 superres_model.load_state_dict(ups_state_dict, strict=False)
 
-upscale_scheduler = DDIMScheduler(timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02)
+upscale_scheduler = DDIMScheduler(
+    timesteps=1000, beta_schedule="linear", beta_start=0.0001, beta_end=0.02, tensor_format="pt"
+)
 
 glide = GLIDE(
     text_unet=text2im_model,
diff --git a/src/diffusers/pipelines/pipeline_bddm.py b/src/diffusers/pipelines/pipeline_bddm.py
index dd2753cbec..ee9e628f4d 100644
--- a/src/diffusers/pipelines/pipeline_bddm.py
+++ b/src/diffusers/pipelines/pipeline_bddm.py
@@ -13,11 +13,18 @@
 
 import math
+
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import tqdm
+
+from ..configuration_utils import ConfigMixin
+from ..modeling_utils import ModelMixin
+from ..pipeline_utils import DiffusionPipeline
+
 
 def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
     """
@@ -41,8 +48,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
     _embed = np.log(10000) / (half_dim - 1)
     _embed = torch.exp(torch.arange(half_dim) * -_embed).cuda()
     _embed = diffusion_steps * _embed
-    diffusion_step_embed = torch.cat((torch.sin(_embed),
-                                      torch.cos(_embed)), 1)
+    diffusion_step_embed = torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)
 
     return diffusion_step_embed
 
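Note: `calc_diffusion_step_embedding` is the standard transformer-style sinusoidal embedding applied to diffusion step indices. A self-contained sketch that mirrors the logic on CPU (the library code above calls `.cuda()`; the function name here is illustrative):

```python
import torch

def sinusoidal_step_embedding(steps, dim):
    # steps: (batch, 1) float tensor of diffusion step indices; dim must be even.
    half_dim = dim // 2
    scale = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim) * -scale)   # (half_dim,)
    args = steps * freqs                                 # broadcasts to (batch, half_dim)
    return torch.cat((torch.sin(args), torch.cos(args)), dim=1)

emb = sinusoidal_step_embedding(torch.tensor([[0.0], [10.0], [999.0]]), 128)
print(emb.shape)  # torch.Size([3, 128])
```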
@@ -62,8 +68,7 @@ class Conv(nn.Module):
     def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
         super().__init__()
         self.padding = dilation * (kernel_size - 1) // 2
-        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size,
-                              dilation=dilation, padding=self.padding)
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding)
         self.conv = nn.utils.weight_norm(self.conv)
         nn.init.kaiming_normal_(self.conv.weight)
 
@@ -89,8 +94,7 @@ class ZeroConv1d(nn.Module):
 # every residual block (named residual layer in paper)
 # contains one noncausal dilated conv
 class ResidualBlock(nn.Module):
-    def __init__(self, res_channels, skip_channels, dilation,
-                 diffusion_step_embed_dim_out):
+    def __init__(self, res_channels, skip_channels, dilation, diffusion_step_embed_dim_out):
         super().__init__()
         self.res_channels = res_channels
 
@@ -98,15 +102,12 @@ class ResidualBlock(nn.Module):
         self.fc_t = nn.Linear(diffusion_step_embed_dim_out, self.res_channels)
 
         # Dilated conv layer
-        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels,
-                                       kernel_size=3, dilation=dilation)
+        self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation)
 
         # Add mel spectrogram upsampler and conditioner conv1x1 layer
         self.upsample_conv2d = nn.ModuleList()
         for s in [16, 16]:
-            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s),
-                                              padding=(1, s // 2),
-                                              stride=(1, s))
+            conv_trans2d = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
             conv_trans2d = nn.utils.weight_norm(conv_trans2d)
             nn.init.kaiming_normal_(conv_trans2d.weight)
             self.upsample_conv2d.append(conv_trans2d)
@@ -152,7 +153,7 @@ class ResidualBlock(nn.Module):
             h += mel_spec
 
         # Gated-tanh nonlinearity
-        out = torch.tanh(h[:, :self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels:, :])
+        out = torch.tanh(h[:, : self.res_channels, :]) * torch.sigmoid(h[:, self.res_channels :, :])
 
         # Residual and skip outputs
         res = self.res_conv(out)
@@ -164,10 +165,16 @@ class ResidualBlock(nn.Module):
 
 
 class ResidualGroup(nn.Module):
-    def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle,
-                 diffusion_step_embed_dim_in,
-                 diffusion_step_embed_dim_mid,
-                 diffusion_step_embed_dim_out):
+    def __init__(
+        self,
+        res_channels,
+        skip_channels,
+        num_res_layers,
+        dilation_cycle,
+        diffusion_step_embed_dim_in,
+        diffusion_step_embed_dim_mid,
+        diffusion_step_embed_dim_out,
+    ):
         super().__init__()
         self.num_res_layers = num_res_layers
         self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
@@ -180,16 +187,19 @@ class ResidualGroup(nn.Module):
         self.residual_blocks = nn.ModuleList()
         for n in range(self.num_res_layers):
             self.residual_blocks.append(
-                ResidualBlock(res_channels, skip_channels,
-                              dilation=2 ** (n % dilation_cycle),
-                              diffusion_step_embed_dim_out=diffusion_step_embed_dim_out))
+                ResidualBlock(
+                    res_channels,
+                    skip_channels,
+                    dilation=2 ** (n % dilation_cycle),
+                    diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
+                )
+            )
 
     def forward(self, input_data):
         x, mel_spectrogram, diffusion_steps = input_data
 
         # Embed diffusion step t
-        diffusion_step_embed = calc_diffusion_step_embedding(
-            diffusion_steps, self.diffusion_step_embed_dim_in)
+        diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in)
         diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
         diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))
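Note: the `2 ** (n % dilation_cycle)` pattern above gives each residual block an exponentially growing dilation that resets every `dilation_cycle` layers, so the receptive field grows quickly while staying bounded. With the DiffWave defaults used below (30 layers, cycle of 10):

```python
num_res_layers, dilation_cycle = 30, 10
dilations = [2 ** (n % dilation_cycle) for n in range(num_res_layers)]
print(dilations[:10])  # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], repeated three times
```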
@@ -206,27 +216,52 @@ class ResidualGroup(nn.Module):
         return skip * math.sqrt(1.0 / self.num_res_layers)
 
 
-class DiffWave(nn.Module):
-    def __init__(self, in_channels, res_channels, skip_channels, out_channels,
-                 num_res_layers, dilation_cycle,
-                 diffusion_step_embed_dim_in,
-                 diffusion_step_embed_dim_mid,
-                 diffusion_step_embed_dim_out):
+class DiffWave(ModelMixin, ConfigMixin):
+    def __init__(
+        self,
+        in_channels=1,
+        res_channels=128,
+        skip_channels=128,
+        out_channels=1,
+        num_res_layers=30,
+        dilation_cycle=10,
+        diffusion_step_embed_dim_in=128,
+        diffusion_step_embed_dim_mid=512,
+        diffusion_step_embed_dim_out=512,
+    ):
         super().__init__()
 
+        # register all init arguments with self.register
+        self.register(
+            in_channels=in_channels,
+            res_channels=res_channels,
+            skip_channels=skip_channels,
+            out_channels=out_channels,
+            num_res_layers=num_res_layers,
+            dilation_cycle=dilation_cycle,
+            diffusion_step_embed_dim_in=diffusion_step_embed_dim_in,
+            diffusion_step_embed_dim_mid=diffusion_step_embed_dim_mid,
+            diffusion_step_embed_dim_out=diffusion_step_embed_dim_out,
+        )
+
         # Initial conv1x1 with relu
         self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU(inplace=False))
 
         # All residual layers
-        self.residual_layer = ResidualGroup(res_channels,
-                                            skip_channels,
-                                            num_res_layers,
-                                            dilation_cycle,
-                                            diffusion_step_embed_dim_in,
-                                            diffusion_step_embed_dim_mid,
-                                            diffusion_step_embed_dim_out)
+        self.residual_layer = ResidualGroup(
+            res_channels,
+            skip_channels,
+            num_res_layers,
+            dilation_cycle,
+            diffusion_step_embed_dim_in,
+            diffusion_step_embed_dim_mid,
+            diffusion_step_embed_dim_out,
+        )
 
         # Final conv1x1 -> relu -> zeroconv1x1
-        self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1),
-                                        nn.ReLU(inplace=False), ZeroConv1d(skip_channels, out_channels))
+        self.final_conv = nn.Sequential(
+            Conv(skip_channels, skip_channels, kernel_size=1),
+            nn.ReLU(inplace=False),
+            ZeroConv1d(skip_channels, out_channels),
+        )
 
     def forward(self, input_data):
         audio, mel_spectrogram, diffusion_steps = input_data
@@ -234,3 +269,45 @@ class DiffWave(nn.Module):
         x = self.init_conv(x).clone()
         x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
         return self.final_conv(x)
+
+
+class BDDM(DiffusionPipeline):
+    def __init__(self, diffwave, noise_scheduler):
+        super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
+        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
+
+    @torch.no_grad()
+    def __call__(self, mel_spectrogram, generator, torch_device=None):
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.diffwave.to(torch_device)
+
+        mel_spectrogram = mel_spectrogram.to(torch_device)
+        audio_length = mel_spectrogram.size(-1) * 256
+        audio_size = (1, 1, audio_length)
+
+        # Sample gaussian noise to begin loop
+        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
+
+        timestep_values = self.noise_scheduler.timestep_values
+        num_prediction_steps = len(self.noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            # 1. predict noise residual
+            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
+            residual = self.diffwave((audio, mel_spectrogram, ts))
+
+            # 2. predict previous mean of audio x_t-1
+            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
+
+            # 3. optionally sample variance
+            variance = 0
+            if t > 0:
+                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
+                variance = self.noise_scheduler.get_variance(t).sqrt() * noise
+
+            # 4. set current audio to prev_audio: x_t -> x_t-1
+            audio = pred_prev_audio + variance
+
+        return audio
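Note: `audio_length = mel_spectrogram.size(-1) * 256` assumes the conditioning mel spectrogram was extracted with a hop length of 256 samples, which is what the tacotron2 checkpoint from the README example produces at 22,050 Hz. A quick sanity check of the implied duration (the frame count is illustrative):

```python
mel_frames = 664        # e.g. mel_spectrogram.size(-1) for a short sentence
hop_length = 256        # audio samples generated per mel frame
sampling_rate = 22050

audio_samples = mel_frames * hop_length
print(audio_samples / sampling_rate)  # ~7.71 seconds of audio
```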
diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py
index 9a45d492b9..1f6495e890 100644
--- a/src/diffusers/pipelines/pipeline_glide.py
+++ b/src/diffusers/pipelines/pipeline_glide.py
@@ -28,13 +28,7 @@ from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
-from transformers.utils import (
-    ModelOutput,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
+from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 
 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
@@ -872,31 +866,31 @@ class GLIDE(DiffusionPipeline):
 
         # Sample gaussian noise to begin loop
         image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            (
+                batch_size,
+                self.upscale_unet.in_channels // 2,
+                self.upscale_unet.resolution,
+                self.upscale_unet.resolution,
+            ),
             generator=generator,
         )
-        image = image.to(torch_device)
+        image = image.to(torch_device) * upsample_temp
 
-        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
-        # Ideally, read DDIM paper in-detail understanding
+        num_trained_timesteps = self.upscale_noise_scheduler.timesteps
+        inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
+        # adapt the beta schedule to the number of steps
+        # self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)
 
-        # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_image -> f_theta(x_t, t) or x_0
-        # - std_dev_t -> sigma_t
-        # - eta -> η
-        # - pred_image_direction -> "direction pointing to x_t"
-        # - pred_prev_image -> "x_t-1"
         for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
             # 1. predict noise residual
             with torch.no_grad():
-                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                time_input = torch.tensor([inference_step_times[t]] * image.shape[0], device=torch_device)
                 model_output = self.upscale_unet(image, time_input, low_res)
             noise_residual, pred_variance = torch.split(model_output, 3, dim=1)
 
             # 2. predict previous mean of image x_t-1
             pred_prev_image = self.upscale_noise_scheduler.step(
-                noise_residual, image, t, num_inference_steps_upscale, eta
+                noise_residual, image, t, num_inference_steps_upscale, eta, use_clipped_residual=True
             )
 
             # 3. optionally sample variance
@@ -910,6 +904,6 @@ class GLIDE(DiffusionPipeline):
             # 4. set current image to prev_image: x_t -> x_t-1
             image = pred_prev_image + variance
 
-        image = image.permute(0, 2, 3, 1)
+        image = image.clamp(-1, 1).permute(0, 2, 3, 1)
 
         return image
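Note: `inference_step_times` maps the shortened upscaling loop back onto the trained timestep grid by plain striding. With illustrative values (1,000 trained steps; `num_inference_steps_upscale` is chosen by the caller):

```python
# Hypothetical values for illustration only.
num_trained_timesteps = 1000
num_inference_steps_upscale = 50

inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
print(list(inference_step_times)[:5])  # [0, 20, 40, 60, 80]
print(len(inference_step_times))       # 50 evenly spaced trained timesteps
```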
diff --git a/src/diffusers/schedulers/glide_ddim.py b/src/diffusers/schedulers/glide_ddim.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 883a358d34..88e4725e75 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
+        trained_betas=None,
+        timestep_values=None,
         clip_predicted_image=True,
         tensor_format="np",
     ):
@@ -37,6 +39,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             beta_schedule=beta_schedule,
         )
         self.timesteps = int(timesteps)
+        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
         self.clip_image = clip_predicted_image
 
         if beta_schedule == "linear":
@@ -69,14 +72,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         #
         # self.register_buffer("log_variance", log_variance.to(torch.float32))
 
-    def rescale_betas(self, num_timesteps):
-        if self.beta_schedule == "linear":
-            scale = self.timesteps / num_timesteps
-            self.betas = linear_beta_schedule(
-                num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
-            )
-            self.alphas = 1.0 - self.betas
-            self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+    # def rescale_betas(self, num_timesteps):
+    #     # GLIDE scaling
+    #     if self.beta_schedule == "linear":
+    #         scale = self.timesteps / num_timesteps
+    #         self.betas = linear_beta_schedule(
+    #             num_timesteps, beta_start=self.beta_start * scale, beta_end=self.beta_end * scale
+    #         )
+    #         self.alphas = 1.0 - self.betas
+    #         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
 
     def get_alpha(self, time_step):
         return self.alphas[time_step]
@@ -107,7 +111,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         return variance
 
-    def step(self, residual, image, t, num_inference_steps, eta):
+    def step(self, residual, image, t, num_inference_steps, eta, use_clipped_residual=False):
         # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
         # Ideally, read DDIM paper in-detail understanding
 
@@ -141,6 +145,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         variance = self.get_variance(t, num_inference_steps)
         std_dev_t = eta * variance ** (0.5)
 
+        if use_clipped_residual:
+            # the residual is always re-derived from the clipped x_0 in GLIDE
+            residual = (image - alpha_prod_t ** (0.5) * pred_original_image) / beta_prod_t ** (0.5)
+
         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
         pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual
 
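Note: the new `use_clipped_residual` branch re-derives the noise residual from the clipped x_0 so that the forward relation x_t = sqrt(alpha_prod_t) * x_0 + sqrt(1 - alpha_prod_t) * eps still holds exactly before formula (12) is applied. A standalone numpy sketch of that identity, with illustrative values:

```python
import numpy as np

# Illustrative coefficients; in the scheduler they come from alphas_cumprod at step t.
alpha_prod_t = 0.5
beta_prod_t = 1 - alpha_prod_t

image = np.random.randn(4)                                # x_t
pred_original_image = np.clip(np.random.randn(4), -1, 1)  # clipped x_0 prediction

# Re-derive eps from the clipped x_0 ...
residual = (image - alpha_prod_t**0.5 * pred_original_image) / beta_prod_t**0.5
# ... so that reassembling x_t from (x_0, eps) reproduces it exactly.
assert np.allclose(alpha_prod_t**0.5 * pred_original_image + beta_prod_t**0.5 * residual, image)
```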
compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * residual diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 22c5da63cc..d5a686b91f 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): beta_start=0.0001, beta_end=0.02, beta_schedule="linear", + trained_betas=None, + timestep_values=None, variance_type="fixed_small", clip_predicted_image=True, tensor_format="np", @@ -36,14 +38,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): beta_start=beta_start, beta_end=beta_end, beta_schedule=beta_schedule, + trained_betas=trained_betas, + timestep_values=timestep_values, variance_type=variance_type, clip_predicted_image=clip_predicted_image, ) self.timesteps = int(timesteps) + self.timestep_values = timestep_values # save the fixed timestep values for BDDM self.clip_image = clip_predicted_image self.variance_type = variance_type - if beta_schedule == "linear": + if trained_betas is not None: + self.betas = np.asarray(trained_betas) + elif beta_schedule == "linear": self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end) elif beta_schedule == "squaredcos_cap_v2": # GLIDE cosine schedule @@ -56,6 +63,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): self.alphas = 1.0 - self.betas self.alphas_cumprod = np.cumprod(self.alphas, axis=0) + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1 - self.alphas_cumprod) self.one = np.array(1.0) self.set_format(tensor_format=tensor_format) @@ -131,5 +140,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): return pred_prev_image + def forward_step(self, original_image, noise, t): + noisy_image = self.sqrt_alphas_cumprod[t] * original_image + self.sqrt_one_minus_alphas_cumprod[t] * noise + return noisy_image + def __len__(self): return self.timesteps diff --git a/src/diffusers/trainers/training_ddpm.py b/src/diffusers/trainers/training_ddpm.py new file mode 100644 index 0000000000..7ac9e52ab0 --- /dev/null +++ b/src/diffusers/trainers/training_ddpm.py @@ -0,0 +1,116 @@ +import random + +import numpy as np +import torch +import torch.nn.functional as F + +import PIL.Image +from accelerate import Accelerator +from datasets import load_dataset +from diffusers import DDPM, DDPMScheduler, UNetModel +from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor +from tqdm.auto import tqdm +from transformers import get_linear_schedule_with_warmup + + +def set_seed(seed): + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +set_seed(0) + +accelerator = Accelerator(mixed_precision="fp16") + +model = UNetModel(ch=128, ch_mult=(1, 2, 4, 8), resolution=64) +noise_scheduler = DDPMScheduler(timesteps=1000) +optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) + +num_epochs = 100 +batch_size = 8 +gradient_accumulation_steps = 8 + +augmentations = Compose( + [ + Resize(64), + CenterCrop(64), + RandomHorizontalFlip(), + ToTensor(), + Lambda(lambda x: x * 2 - 1), + ] +) +dataset = load_dataset("huggan/pokemon", split="train") + + +def transforms(examples): + images = [augmentations(image.convert("RGB")) for image 
in examples["image"]] + return {"input": images} + + +dataset = dataset.shuffle(seed=0) +dataset.set_transform(transforms) +train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) + +lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=1000, + num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps, +) + +model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, lr_scheduler +) + +for epoch in range(num_epochs): + model.train() + pbar = tqdm(total=len(train_dataloader), unit="ba") + pbar.set_description(f"Epoch {epoch}") + for step, batch in enumerate(train_dataloader): + clean_images = batch["input"] + noisy_images = torch.empty_like(clean_images) + bsz = clean_images.shape[0] + + timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long() + for idx in range(bsz): + noise = torch.randn_like(clean_images[0]).to(clean_images.device) + noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx]) + + if step % gradient_accumulation_steps == 0: + with accelerator.no_sync(model): + output = model(noisy_images, timesteps) + loss = F.l1_loss(output, clean_images) + accelerator.backward(loss) + else: + output = model(noisy_images, timesteps) + loss = F.l1_loss(output, clean_images) + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + pbar.update(1) + pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) + + optimizer.step() + + # eval + model.eval() + with torch.no_grad(): + pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler) + generator = torch.Generator() + generator = generator.manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + image = pipeline(generator=generator) + + # process image to PIL + image_processed = image.cpu().permute(0, 2, 3, 1) + image_processed = (image_processed + 1.0) * 127.5 + image_processed = image_processed.type(torch.uint8).numpy() + image_pil = PIL.Image.fromarray(image_processed[0]) + + # save image + pipeline.save_pretrained("./poke-ddpm") + image_pil.save(f"./poke-ddpm/test_{epoch}.png")