diff --git a/src/diffusers/pipelines/pipeline_bddm.py b/src/diffusers/pipelines/pipeline_bddm.py
index dd2753cbec..85dddff33e 100644
--- a/src/diffusers/pipelines/pipeline_bddm.py
+++ b/src/diffusers/pipelines/pipeline_bddm.py
@@ -17,6 +17,9 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import tqdm
+
+from ..pipeline_utils import DiffusionPipeline
 
 
 def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
@@ -234,3 +237,71 @@ class DiffWave(nn.Module):
         x = self.init_conv(x).clone()
         x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
         return self.final_conv(x)
+
+
+class BDDMPipeline(DiffusionPipeline):
+    """Pipeline that generates audio from a mel spectrogram with DiffWave.
+
+    Starting from Gaussian noise, each step asks ``diffwave`` for a noise
+    residual and lets ``noise_scheduler`` compute the previous sample from it.
+    """
+
+    def __init__(self, diffwave, noise_scheduler):
+        """Register ``diffwave`` and the scheduler returned by ``set_format("pt")``."""
+        super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
+        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
+
+    @torch.no_grad()
+    def __call__(self, mel_spectrogram, generator, torch_device=None):
+        """Run the reverse diffusion loop and return the generated audio.
+
+        Args:
+            mel_spectrogram: conditioning tensor; the size of its last dimension
+                times ``self.config.hop_len`` is the number of samples generated.
+            generator: passed as ``generator`` to every ``torch.normal`` call.
+            torch_device: device to run on; when ``None``, "cuda" is used if
+                available and "cpu" otherwise.
+
+        Returns:
+            The sample left after the last step; the initial noise it is
+            derived from has shape ``(1, 1, audio_length)``.
+        """
+        if torch_device is None:
+            torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.diffwave.to(torch_device)
+        # The model and the audio live on torch_device, so the conditioning must too.
+        mel_spectrogram = mel_spectrogram.to(torch_device)
+
+        # NOTE(review): assumes the pipeline config defines `hop_len` - confirm.
+        audio_length = mel_spectrogram.size(-1) * self.config.hop_len
+        audio_size = (1, 1, audio_length)
+
+        # Sample gaussian noise to begin loop
+        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
+
+        timestep_values = self.noise_scheduler.timestep_values
+        num_prediction_steps = len(self.noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            # 1. predict noise residual
+            # `ts` is the timestep value given to the model; `t` stays the integer
+            # step index used by the scheduler calls and the `t > 0` check below.
+            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
+            # NOTE(review): DiffWave.forward's signature is outside this hunk -
+            # confirm it accepts these three positional arguments.
+            residual = self.diffwave(audio, mel_spectrogram, ts)
+
+            # 2. predict previous mean of audio x_t-1
+            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
+
+            # 3. optionally sample variance
+            variance = 0
+            if t > 0:
+                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
+                variance = self.noise_scheduler.get_variance(t).sqrt() * noise
+
+            # 4. set current audio to prev_audio: x_t -> x_t-1
+            audio = pred_prev_audio + variance
+
+        return audio
\ No newline at end of file
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 883a358d34..842348d106 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
+        trained_betas=None,
+        timestep_values=None,
         clip_predicted_image=True,
         tensor_format="np",
     ):
@@ -37,9 +39,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             beta_schedule=beta_schedule,
         )
         self.timesteps = int(timesteps)
+        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
         self.clip_image = clip_predicted_image
 
-        if beta_schedule == "linear":
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
             self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule