mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Merge remote-tracking branch 'origin/main'
# Conflicts: # src/diffusers/__init__.py # src/diffusers/pipelines/__init__.py # src/diffusers/schedulers/scheduling_ddim.py
This commit is contained in:
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
|
||||
from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
|
||||
from .models.unet_ldm import UNetLDMModel
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .pipelines import DDIM, DDPM, GLIDE, BDDMPipeline, LatentDiffusion
|
||||
from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDM
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
|
||||
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .pipeline_bddm import BDDMPipeline
|
||||
from .pipeline_ddim import DDIM
|
||||
from .pipeline_ddpm import DDPM
|
||||
from .pipeline_glide import GLIDE
|
||||
from .pipeline_latent_diffusion import LatentDiffusion
|
||||
from .pipeline_bddm import BDDM
|
||||
|
||||
@@ -271,20 +271,21 @@ class DiffWave(ModelMixin, ConfigMixin):
|
||||
return self.final_conv(x)
|
||||
|
||||
|
||||
class BDDMPipeline(DiffusionPipeline):
|
||||
class BDDM(DiffusionPipeline):
|
||||
def __init__(self, diffwave, noise_scheduler):
|
||||
super().__init__()
|
||||
noise_scheduler = noise_scheduler.set_format("pt")
|
||||
self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, mel_spectrogram, generator):
|
||||
def __call__(self, mel_spectrogram, generator, torch_device=None):
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.diffwave.to(torch_device)
|
||||
|
||||
audio_length = mel_spectrogram.size(-1) * self.config.hop_len
|
||||
|
||||
mel_spectrogram = mel_spectrogram.to(torch_device)
|
||||
audio_length = mel_spectrogram.size(-1) * 256
|
||||
audio_size = (1, 1, audio_length)
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
@@ -294,9 +295,8 @@ class BDDMPipeline(DiffusionPipeline):
|
||||
num_prediction_steps = len(self.noise_scheduler)
|
||||
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
|
||||
# 1. predict noise residual
|
||||
with torch.no_grad():
|
||||
t = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
|
||||
residual = self.diffwave(audio, mel_spectrogram, t)
|
||||
ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
|
||||
residual = self.diffwave((audio, mel_spectrogram, ts))
|
||||
|
||||
# 2. predict previous mean of audio x_t-1
|
||||
pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
|
||||
|
||||
@@ -42,9 +42,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
|
||||
self.clip_image = clip_predicted_image
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = np.asarray(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
if beta_schedule == "linear":
|
||||
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# GLIDE cosine schedule
|
||||
|
||||
@@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start=0.0001,
|
||||
beta_end=0.02,
|
||||
beta_schedule="linear",
|
||||
trained_betas=None,
|
||||
timestep_values=None,
|
||||
variance_type="fixed_small",
|
||||
clip_predicted_image=True,
|
||||
tensor_format="np",
|
||||
@@ -36,14 +38,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
beta_start=beta_start,
|
||||
beta_end=beta_end,
|
||||
beta_schedule=beta_schedule,
|
||||
trained_betas=trained_betas,
|
||||
timestep_values=timestep_values,
|
||||
variance_type=variance_type,
|
||||
clip_predicted_image=clip_predicted_image,
|
||||
)
|
||||
self.timesteps = int(timesteps)
|
||||
self.timestep_values = timestep_values # save the fixed timestep values for BDDM
|
||||
self.clip_image = clip_predicted_image
|
||||
self.variance_type = variance_type
|
||||
|
||||
if beta_schedule == "linear":
|
||||
if trained_betas is not None:
|
||||
self.betas = np.asarray(trained_betas)
|
||||
elif beta_schedule == "linear":
|
||||
self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# GLIDE cosine schedule
|
||||
|
||||
Reference in New Issue
Block a user