1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge remote-tracking branch 'origin/main'

This commit is contained in:
anton-l
2022-06-13 14:33:56 +02:00
4 changed files with 53 additions and 2 deletions

View File

@@ -9,6 +9,6 @@ from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDMPipeline
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler

View File

@@ -2,3 +2,4 @@ from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_bddm import BDDMPipeline

View File

@@ -17,6 +17,9 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from ..pipeline_utils import DiffusionPipeline
def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
@@ -234,3 +237,45 @@ class DiffWave(nn.Module):
x = self.init_conv(x).clone()
x = self.residual_layer((x, mel_spectrogram, diffusion_steps))
return self.final_conv(x)
class BDDMPipeline(DiffusionPipeline):
def __init__(self, diffwave, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
@torch.no_grad()
def __call__(self, mel_spectrogram, generator):
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.diffwave.to(torch_device)
audio_length = mel_spectrogram.size(-1) * self.config.hop_len
audio_size = (1, 1, audio_length)
# Sample gaussian noise to begin loop
audio = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
timestep_values = self.noise_scheduler.timestep_values
num_prediction_steps = len(self.noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
# 1. predict noise residual
with torch.no_grad():
t = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
residual = self.diffwave(audio, mel_spectrogram, t)
# 2. predict previous mean of audio x_t-1
pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
# 3. optionally sample variance
variance = 0
if t > 0:
noise = torch.normal(0, 1, size=audio_size, generator=generator).to(torch_device)
variance = self.noise_scheduler.get_variance(t).sqrt() * noise
# 4. set current audio to prev_audio: x_t -> x_t-1
audio = pred_prev_audio + variance
return audio

View File

@@ -26,6 +26,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
clip_predicted_image=True,
tensor_format="np",
):
@@ -37,9 +39,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule=beta_schedule,
)
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image
if beta_schedule == "linear":
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule