Rename BDDMPipeline to BDDM and move trained_betas/timestep_values support to DDPMScheduler.

Summary of what the hunks below change:
  * src/diffusers/__init__.py and src/diffusers/pipelines/__init__.py export
    BDDM instead of BDDMPipeline.
  * pipeline_bddm.py: the pipeline class is renamed to BDDM. __call__ gains a
    torch_device=None parameter, moves mel_spectrogram to that device, computes
    audio_length with a literal 256 instead of self.config.hop_len, drops the
    inner torch.no_grad() block (the method is already decorated with
    @torch.no_grad()), keeps the loop index t intact by naming the timestep
    tensor ts, and passes (audio, mel_spectrogram, ts) to self.diffwave as a
    single tuple.
  * scheduling_ddim.py: the trained_betas branch of the beta selection is removed.
  * scheduling_ddpm.py: __init__ accepts trained_betas and timestep_values,
    forwards both in the keyword-argument call shown in the second hunk, stores
    self.timestep_values, and uses trained_betas when it is not None.

NOTE(review): 256 presumably stands in for the hop length; confirm it matches
the checkpoint config before hard-coding it.
NOTE(review): whether DiffWave.forward accepts a single tuple argument is not
visible in these hunks; confirm.
NOTE(review): the DDIM hunk does not show the __init__ signature, so it is
unclear whether the trained_betas parameter itself is also removed; confirm.
NOTE(review): the blank line added after self.diffwave.to(torch_device)
replaces an empty line, so it appears to carry trailing whitespace; confirm
and strip.

diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 7db65141ac..f93cd2943f 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -9,6 +9,6 @@ from .models.unet import UNetModel
 from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIM, DDPM, GLIDE, BDDMPipeline, LatentDiffusion
+from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, BDDM
 from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index a6a37154cb..ad42aead20 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -1,5 +1,5 @@
-from .pipeline_bddm import BDDMPipeline
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
 from .pipeline_glide import GLIDE
 from .pipeline_latent_diffusion import LatentDiffusion
+from .pipeline_bddm import BDDM
diff --git a/src/diffusers/pipelines/pipeline_bddm.py b/src/diffusers/pipelines/pipeline_bddm.py
index f61c8d4cf6..ee9e628f4d 100644
--- a/src/diffusers/pipelines/pipeline_bddm.py
+++ b/src/diffusers/pipelines/pipeline_bddm.py
@@ -271,20 +271,21 @@ class DiffWave(ModelMixin, ConfigMixin):
         return self.final_conv(x)
 
 
-class BDDMPipeline(DiffusionPipeline):
+class BDDM(DiffusionPipeline):
     def __init__(self, diffwave, noise_scheduler):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
         self.register_modules(diffwave=diffwave, noise_scheduler=noise_scheduler)
 
     @torch.no_grad()
-    def __call__(self, mel_spectrogram, generator):
+    def __call__(self, mel_spectrogram, generator, torch_device=None):
         if torch_device is None:
             torch_device = "cuda" if torch.cuda.is_available() else "cpu"
 
         self.diffwave.to(torch_device)
-
-        audio_length = mel_spectrogram.size(-1) * self.config.hop_len
+        
+        mel_spectrogram = mel_spectrogram.to(torch_device)
+        audio_length = mel_spectrogram.size(-1) * 256
         audio_size = (1, 1, audio_length)
 
         # Sample gaussian noise to begin loop
@@ -294,9 +295,8 @@ class BDDMPipeline(DiffusionPipeline):
         num_prediction_steps = len(self.noise_scheduler)
         for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             # 1. predict noise residual
-            with torch.no_grad():
-                t = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
-                residual = self.diffwave(audio, mel_spectrogram, t)
+            ts = (torch.tensor(timestep_values[t]) * torch.ones((1, 1))).to(torch_device)
+            residual = self.diffwave((audio, mel_spectrogram, ts))
 
             # 2. predict previous mean of audio x_t-1
             pred_prev_audio = self.noise_scheduler.step(residual, audio, t)
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index 7332e9a912..88e4725e75 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -42,9 +42,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
         self.clip_image = clip_predicted_image
 
-        if trained_betas is not None:
-            self.betas = np.asarray(trained_betas)
-        elif beta_schedule == "linear":
+        if beta_schedule == "linear":
             self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 5e4612494f..d5a686b91f 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -26,6 +26,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
+        trained_betas=None,
+        timestep_values=None,
         variance_type="fixed_small",
         clip_predicted_image=True,
         tensor_format="np",
@@ -36,14 +38,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
+            trained_betas=trained_betas,
+            timestep_values=timestep_values,
             variance_type=variance_type,
             clip_predicted_image=clip_predicted_image,
         )
         self.timesteps = int(timesteps)
+        self.timestep_values = timestep_values  # save the fixed timestep values for BDDM
         self.clip_image = clip_predicted_image
         self.variance_type = variance_type
 
-        if beta_schedule == "linear":
+        if trained_betas is not None:
+            self.betas = np.asarray(trained_betas)
+        elif beta_schedule == "linear":
             self.betas = linear_beta_schedule(timesteps, beta_start=beta_start, beta_end=beta_end)
         elif beta_schedule == "squaredcos_cap_v2":
             # GLIDE cosine schedule