mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add BDDMPipeline
This commit is contained in:
@@ -17,6 +17,9 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
|
||||
@@ -234,3 +237,45 @@ class DiffWave(nn.Module):
|
||||
x = self.init_conv(x).clone()
|
||||
x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
|
||||
return self.final_conv(x)
|
||||
|
||||
|
||||
class BDDMPipeline(DiffusionPipeline):
    """Pipeline that generates audio from a mel spectrogram.

    Starting from gaussian noise, it repeatedly asks the ``diffwave`` model for
    a noise residual and lets ``noise_scheduler`` turn that into the previous
    (less noisy) audio sample.
    """

    def __init__(self, diffwave, noise_scheduler):
        """Register the denoising model and the scheduler on the pipeline.

        Args:
            diffwave: model called as ``diffwave(audio, mel_spectrogram, t)``
                to predict the noise residual.
            noise_scheduler: scheduler providing ``timestep_values``,
                ``step``, ``get_variance`` and ``__len__``.
        """
        super().__init__()
        # NOTE(review): relies on `set_format` returning the scheduler itself
        # (not None) -- confirm against the scheduler implementation.
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)

    @torch.no_grad()
    def __call__(self, mel_spectrogram, generator, torch_device=None):
        """Run the reverse diffusion loop and return the generated audio.

        Args:
            mel_spectrogram: conditioning tensor; its last dimension times
                ``self.config.hop_len`` gives the number of audio samples.
            generator: ``torch.Generator`` used for all noise sampling.
            torch_device: device to run on. When ``None``, "cuda" is used if
                available, otherwise "cpu".

        Returns:
            Tensor of shape ``(1, 1, audio_length)`` on ``torch_device``.
        """
        # `torch_device` used to be read without ever being defined, which
        # raised UnboundLocalError on every call; it is now an optional
        # keyword argument so existing two-argument callers keep working.
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.diffwave.to(torch_device)

        audio_length = mel_spectrogram.size(-1) * self.config.hop_len
        audio_size = (1, 1, audio_length)

        # Sample gaussian noise to begin loop
        audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)

        timestep_values = self.noise_scheduler.timestep_values
        num_prediction_steps = len(self.noise_scheduler)
        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
            # 1. predict noise residual
            # Keep the model's timestep tensor in its own name: overwriting the
            # integer loop index `t` would hand the scheduler a timestep *value*
            # tensor where it needs the step *index*.
            # The enclosing @torch.no_grad() already disables autograd here.
            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
            # NOTE(review): confirm DiffWave.forward takes three positional
            # arguments rather than a single tuple.
            residual = self.diffwave(audio, mel_spectrogram, ts)

            # 2. predict previous mean of audio x_t-1
            pred_prev_audio = self.noise_scheduler.step(residual, audio, t)

            # 3. optionally sample variance (no noise is added on the last step)
            variance = 0
            if t > 0:
                noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
                variance = self.noise_scheduler.get_variance(t).sqrt() * noise

            # 4. set current audio to prev_audio: x_t -> x_t-1
            audio = pred_prev_audio + variance

        return audio
|
||||
@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start=0.0001,
|
||||
beta_end=0.02,
|
||||
beta_schedule="linear",
|
||||
trained_betas=None,
|
||||
timestep_values=None,
|
||||
clip_predicted_image=True,
|
||||
tensor_format="np",
|
||||
):
|
||||
@@ -37,9 +39,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_schedule=beta_schedule,
|
||||
)
|
||||
self.timesteps = int(timesteps)
|
||||
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
|
||||
self.clip_image = clip_predicted_image
|
||||
|
||||
if beta_schedule == "linear":
|
||||
if trained_betas is not None:
|
||||
self.betas = np.asarray(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# GLIDE cosine schedule
|
||||
|
||||
Reference in New Issue
Block a user