1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

update grad tts pipeline

This commit is contained in:
patil-suraj
2022-06-16 16:48:23 +02:00
parent 71ecc7aed8
commit 2d8d82f93e

View File

@@ -357,7 +357,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
self.window_size = window_size
self.spk_emb_dim = spk_emb_dim
self.n_spks = n_spks
self.emb = torch.nn.Embedding(n_vocab, n_channels)
torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)
@@ -371,7 +371,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
kernel_size, p_dropout)
def forward(self, x, x_lengths, spk=None):
def forward(self, x, x_lengths, spk=None):
x = self.emb(x) * math.sqrt(self.n_channels)
x = torch.transpose(x, 1, -1)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
@@ -385,19 +385,47 @@ class TextEncoder(ModelMixin, ConfigMixin):
x_dp = torch.detach(x)
logw = self.proj_w(x_dp, x_mask)
return mu, logw, x_mask
return mu, logw, x_mask, spk
class GradTTS(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(diffwave=unet, noise_scheduler=noise_scheduler)
self.register_modules(diffwave=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)
@torch.no_grad()
def __call__(self, text, speaker_id, num_inference_steps, generator, torch_device=None):
def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None):
if torch_device is None:
torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pass
x, x_lengths = self.tokenizer(text)
if speaker_id is not None:
speaker_id= torch.longTensor([speaker_id])
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
mu_x, logw, x_mask = self.text_encoder(x, x_lengths)
w = torch.exp(logw) * x_mask
w_ceil = torch.ceil(w) * length_scale
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = int(y_lengths.max())
y_max_length_ = fix_len_compatibility(y_max_length)
# Using obtained durations `w` construct alignment map `attn`
y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
# Align encoded text and get mu_y
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
encoder_outputs = mu_y[:, :, :y_max_length]
# Sample latent representation from terminal distribution N(mu_y, I)
z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature