From 2d8d82f93e95b238ef2f5dc617ebc9f0786f9efb Mon Sep 17 00:00:00 2001
From: patil-suraj
Date: Thu, 16 Jun 2022 16:48:23 +0200
Subject: [PATCH] update grad tts pipeline

---
 src/diffusers/pipelines/pipeline_grad_tts.py | 42 ++++++++++++++++----
 1 file changed, 35 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py
index c32d77e762..d5a23d9677 100644
--- a/src/diffusers/pipelines/pipeline_grad_tts.py
+++ b/src/diffusers/pipelines/pipeline_grad_tts.py
@@ -357,7 +357,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
         self.window_size = window_size
         self.spk_emb_dim = spk_emb_dim
         self.n_spks = n_spks
-        
+
         self.emb = torch.nn.Embedding(n_vocab, n_channels)
         torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
 
@@ -371,7 +371,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
         self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
                                         kernel_size, p_dropout)
 
-    def forward(self, x, x_lengths, spk=None):    
+    def forward(self, x, x_lengths, spk=None):
         x = self.emb(x) * math.sqrt(self.n_channels)
         x = torch.transpose(x, 1, -1)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
@@ -385,19 +385,47 @@ class TextEncoder(ModelMixin, ConfigMixin):
         x_dp = torch.detach(x)
         logw = self.proj_w(x_dp, x_mask)
 
-        return mu, logw, x_mask
+        return mu, logw, x_mask, spk
 
 
 class GradTTS(DiffusionPipeline):
-    def __init__(self, unet, noise_scheduler):
+    def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(diffwave=unet, noise_scheduler=noise_scheduler)
+        self.register_modules(diffwave=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
 
     @torch.no_grad()
-    def __call__(self, text, speaker_id, num_inference_steps, generator, torch_device=None):
+    def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None):
         if torch_device is None:
             torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-        pass
+        x, x_lengths = self.tokenizer(text)
+
+        if speaker_id is not None:
+            speaker_id = torch.LongTensor([speaker_id])
+
+        # Get encoder outputs `mu_x` and log-scaled token durations `logw`
+        mu_x, logw, x_mask, spk = self.text_encoder(x, x_lengths, spk=speaker_id)
+
+        w = torch.exp(logw) * x_mask
+        w_ceil = torch.ceil(w) * length_scale
+        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+        y_max_length = int(y_lengths.max())
+        y_max_length_ = fix_len_compatibility(y_max_length)
+
+        # Using the obtained durations `w`, construct the alignment map `attn`
+        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
+        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
+        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
+
+        # Align encoded text and get mu_y
+        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
+        mu_y = mu_y.transpose(1, 2)
+        encoder_outputs = mu_y[:, :, :y_max_length]
+
+        # Sample latent representation from terminal distribution N(mu_y, I)
+        z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
+
+        
\ No newline at end of file
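
Note: the patch stops right after sampling `z` from the terminal distribution N(mu_y, I), so `__call__` does not yet run the reverse diffusion or return anything. For orientation, below is a minimal sketch of what the missing decoder loop could look like, following the Grad-TTS paper's fixed-step Euler discretization of the reverse SDE. The helper name `reverse_diffusion`, the estimator call signature, and the linear schedule defaults `beta_min=0.05` / `beta_max=20.0` are assumptions for illustration, not part of this patch or of the diffusers scheduler API.

    import torch

    def reverse_diffusion(estimator, z, mask, mu, num_inference_steps,
                          beta_min=0.05, beta_max=20.0, spk=None):
        """Sketch of the Grad-TTS reverse SDE solved with fixed Euler steps.

        Assumptions: `estimator` is the registered unet with signature
        (xt, mask, mu, t, spk) -> score estimate, and beta(t) is linear.
        """
        h = 1.0 / num_inference_steps  # step size on the unit time interval
        xt = z * mask
        for i in range(num_inference_steps):
            # Midpoint of the current step; time runs backwards from 1 to 0.
            t = (1.0 - (i + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
            beta_t = (beta_min + (beta_max - beta_min) * t).view(-1, 1, 1)  # linear beta(t)
            # Euler update: the drift pulls xt toward mu, corrected by the score estimate.
            dxt = 0.5 * (mu - xt - estimator(xt, mask, mu, t, spk)) * (beta_t * h)
            xt = (xt - dxt) * mask
        return xt

In the pipeline above this would run as something like `decoder_outputs = reverse_diffusion(self.diffwave, z * y_mask, y_mask, mu_y, num_inference_steps)`, with the result cropped to `[:, :, :y_max_length]` the same way `encoder_outputs` is.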