diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py
index 51be34efde..8680b7542a 100644
--- a/src/diffusers/pipelines/pipeline_glide.py
+++ b/src/diffusers/pipelines/pipeline_glide.py
@@ -694,6 +694,7 @@ class CLIPTextModel(CLIPPreTrainedModel):
 # END OF THE CLIP MODEL COPY-PASTE
 #####################
 
+
 def _extract_into_tensor(arr, timesteps, broadcast_shape):
     """
     Extract values from a 1-D numpy array for a batch of indices.
diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py
index 3ad6bc8146..743104e658 100644
--- a/src/diffusers/pipelines/pipeline_grad_tts.py
+++ b/src/diffusers/pipelines/pipeline_grad_tts.py
@@ -475,13 +475,15 @@ class GradTTSPipeline(DiffusionPipeline):
         xt = z * y_mask
         h = 1.0 / num_inference_steps
 
+        # (Patrick: TODO)
         for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
+            t_new = num_inference_steps - t - 1
             t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
-            time = t.unsqueeze(-1).unsqueeze(-1)
 
             residual = self.unet(xt, t, mu_y, y_mask, speaker_id)
-            xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
+            scheduler_residual = residual - mu_y + xt
+            xt = self.noise_scheduler.step(scheduler_residual, xt, t_new, num_inference_steps)
             xt = xt * y_mask
 
         return xt[:, :, :y_max_length]
 
diff --git a/src/diffusers/schedulers/scheduling_grad_tts.py b/src/diffusers/schedulers/scheduling_grad_tts.py
index 94b5f2ac55..4dc6638de3 100644
--- a/src/diffusers/schedulers/scheduling_grad_tts.py
+++ b/src/diffusers/schedulers/scheduling_grad_tts.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
+
 from ..configuration_utils import ConfigMixin
 from .scheduling_utils import SchedulerMixin
 
@@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin
 class GradTTSScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
-        beta_start=0.0001,
-        beta_end=0.02,
+        beta_start=0.05,
+        beta_end=20,
         tensor_format="np",
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
         )
         self.set_format(tensor_format=tensor_format)
+        self.betas = None
 
-    def sample_noise(self, timestep):
-        noise = self.beta_start + (self.beta_end - self.beta_start) * timestep
-        return noise
+    def get_timesteps(self, num_inference_steps):
+        return np.array([(t + 0.5) / num_inference_steps for t in range(num_inference_steps)])
 
-    def step(self, xt, residual, mu, h, timestep):
-        noise_t = self.sample_noise(timestep)
-        dxt = 0.5 * (mu - xt - residual)
-        dxt = dxt * noise_t * h
-        xt = xt - dxt
-        return xt
+    def set_betas(self, num_inference_steps):
+        timesteps = self.get_timesteps(num_inference_steps)
+        self.betas = np.array([self.beta_start + (self.beta_end - self.beta_start) * t for t in timesteps])
 
-    def __len__(self):
-        return len(self.config.timesteps)
+    def step(self, residual, sample, t, num_inference_steps):
+        # This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix)
+        if self.betas is None:
+            self.set_betas(num_inference_steps)
+
+        beta_t = self.betas[t]
+        beta_t_deriv = beta_t / num_inference_steps
+
+        sample_deriv = residual * beta_t_deriv / 2
+
+        sample = sample + sample_deriv
+        return sample
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index f75bce88a9..db4ed6eb02 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -31,6 +31,7 @@ from diffusers import (
     GlideSuperResUNetModel,
     GlideTextToImageUNetModel,
     GradTTSPipeline,
+    GradTTSScheduler,
     LatentDiffusionPipeline,
     PNDMPipeline,
     PNDMScheduler,
@@ -705,6 +706,8 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_grad_tts(self):
         model_id = "fusing/grad-tts-libri-tts"
         grad_tts = GradTTSPipeline.from_pretrained(model_id)
+        noise_scheduler = GradTTSScheduler()
+        grad_tts.noise_scheduler = noise_scheduler
 
         text = "Hello world, I missed you so much."
         generator = torch.manual_seed(0)
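
For reference, a minimal usage sketch of the reworked scheduler interface, mirroring the updated pipeline loop above. Only `GradTTSScheduler`, `get_timesteps`, `set_betas`, and `step(residual, sample, t, num_inference_steps)` come from this diff; the array shapes and the random `residual` are placeholders for the real UNet output and mel-spectrogram sample.

```python
# Illustrative sketch only: the random "residual" stands in for
# (unet_output - mu_y + xt) computed inside GradTTSPipeline.
import numpy as np

from diffusers import GradTTSScheduler

scheduler = GradTTSScheduler()  # defaults to beta_start=0.05, beta_end=20
num_inference_steps = 50

# Continuous timesteps in (0, 1), midpoint-sampled as in get_timesteps()
timesteps = scheduler.get_timesteps(num_inference_steps)
assert timesteps.shape == (num_inference_steps,)

sample = np.random.randn(1, 80, 100)  # placeholder for xt = z * y_mask

for i in range(num_inference_steps):
    t = num_inference_steps - i - 1  # mirrors t_new in the pipeline loop
    residual = np.random.randn(*sample.shape)  # placeholder model output
    # step() lazily builds self.betas on the first call, then applies
    # sample += residual * (beta_t / num_inference_steps) / 2
    sample = scheduler.step(residual, sample, t, num_inference_steps)
```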