1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge remote-tracking branch 'origin/main'

# Conflicts:
#	src/diffusers/__init__.py
#	src/diffusers/pipelines/__init__.py
#	src/diffusers/schedulers/scheduling_ddim.py
This commit is contained in:
anton-l
2022-06-13 16:52:12 +02:00
5 changed files with 18 additions and 13 deletions

View File

@@ -9,6 +9,6 @@ from .models.unet import UNetModel
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
from .models.unet_ldm import UNetLDMModel
from .pipeline_utils import DiffusionPipeline
from .pipelines import DDIM, DDPM, GLIDE, BDDMPipeline, LatentDiffusion
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDM
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler

View File

@@ -1,5 +1,5 @@
from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIM
from .pipeline_ddpm import DDPM
from .pipeline_glide import GLIDE
from .pipeline_latent_diffusion import LatentDiffusion
from .pipeline_bddm import BDDM

View File

@@ -271,20 +271,21 @@ class DiffWave(ModelMixin, ConfigMixin):
return self.final_conv(x)
class BDDMPipeline(DiffusionPipeline):
class BDDM(DiffusionPipeline):
def __init__(self, diffwave, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
@torch.no_grad()
def __call__(self, mel_spectrogram, generator):
def __call__(self, mel_spectrogram, generator, torch_device=None):
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
self.diffwave.to(torch_device)
audio_length = mel_spectrogram.size(-1) * self.config.hop_len
mel_spectrogram = mel_spectrogram.to(torch_device)
audio_length = mel_spectrogram.size(-1) * 256
audio_size = (1, 1, audio_length)
# Sample gaussian noise to begin loop
@@ -294,9 +295,8 @@ class BDDMPipeline(DiffusionPipeline):
num_prediction_steps = len(self.noise_scheduler)
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
# 1. predict noise residual
with torch.no_grad():
t = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
residual = self.diffwave(audio, mel_spectrogram, t)
ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
residual = self.diffwave((audio, mel_spectrogram, ts))
# 2. predict previous mean of audio x_t-1
pred_prev_audio = self.noise_scheduler.step(residual, audio, t)

View File

@@ -42,9 +42,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
if beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule

View File

@@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
variance_type="fixed_small",
clip_predicted_image=True,
tensor_format="np",
@@ -36,14 +38,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_start=beta_start,
beta_end=beta_end,
beta_schedule=beta_schedule,
trained_betas=trained_betas,
timestep_values=timestep_values,
variance_type=variance_type,
clip_predicted_image=clip_predicted_image,
)
self.timesteps = int(timesteps)
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
self.clip_image = clip_predicted_image
self.variance_type = variance_type
if beta_schedule == "linear":
if trained_betas is not None:
self.betas = np.asarray(trained_betas)
elif beta_schedule == "linear":
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
elif beta_schedule == "squaredcos_cap_v2":
# GLIDE cosine schedule