update grad tts pipeline
@@ -357,7 +357,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
         self.window_size = window_size
         self.spk_emb_dim = spk_emb_dim
         self.n_spks = n_spks


         self.emb = torch.nn.Embedding(n_vocab, n_channels)
         torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5)

@@ -371,7 +371,7 @@ class TextEncoder(ModelMixin, ConfigMixin):
         self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp,
                                         kernel_size, p_dropout)

-    def forward(self, x, x_lengths, spk=None):
+    def forward(self, x, x_lengths, spk=None):
         x = self.emb(x) * math.sqrt(self.n_channels)
         x = torch.transpose(x, 1, -1)
         x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
@@ -385,19 +385,47 @@ class TextEncoder(ModelMixin, ConfigMixin):
         x_dp = torch.detach(x)
         logw = self.proj_w(x_dp, x_mask)

-        return mu, logw, x_mask
+        return mu, logw, x_mask, spk


 class GradTTS(DiffusionPipeline):
-    def __init__(self, unet, noise_scheduler):
+    def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
         super().__init__()
         noise_scheduler = noise_scheduler.set_format("pt")
-        self.register_modules(diffwave=unet, noise_scheduler=noise_scheduler)
+        self.register_modules(diffwave=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer)

     @torch.no_grad()
-    def __call__(self, text, speaker_id, num_inference_steps, generator, torch_device=None):
+    def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None):
         if torch_device is None:
             torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

         x, x_lengths = self.tokenizer(text)

         if speaker_id is not None:
             speaker_id = torch.LongTensor([speaker_id])

         # Get encoder outputs `mu_x` and log-scaled token durations `logw`
         mu_x, logw, x_mask = self.text_encoder(x, x_lengths)

         w = torch.exp(logw) * x_mask
         w_ceil = torch.ceil(w) * length_scale
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
         y_max_length = int(y_lengths.max())
         y_max_length_ = fix_len_compatibility(y_max_length)

         # Using the obtained durations `w`, construct the alignment map `attn`
         y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
         attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
         attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

         # Align the encoded text and get mu_y
         mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
         mu_y = mu_y.transpose(1, 2)
         encoder_outputs = mu_y[:, :, :y_max_length]

         # Sample a latent representation from the terminal distribution N(mu_y, I)
         z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
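The duration-to-alignment block above (`w_ceil` → `generate_path` → `attn`) expands each text token into a contiguous run of mel frames. Below is a minimal numeric sketch of that expansion, assuming integer frame counts; it is not code from this commit and it reimplements the idea directly instead of calling the module's `sequence_mask`/`generate_path` helpers:

    import torch

    # Hypothetical per-token durations for a 3-token utterance, shape (batch, T_text).
    w_ceil = torch.tensor([[2.0, 1.0, 3.0]])

    cum = torch.cumsum(w_ceil, dim=1)   # cumulative frame ends:   [2., 3., 6.]
    start = cum - w_ceil                # cumulative frame starts: [0., 2., 3.]
    y_len = int(cum[0, -1])             # total number of mel frames: 6

    frames = torch.arange(y_len)        # frame indices 0..5
    # attn[b, i, j] = 1 when mel frame j is assigned to text token i
    attn = ((frames >= start.unsqueeze(-1)) & (frames < cum.unsqueeze(-1))).float()

    print(attn[0])
    # tensor([[1., 1., 0., 0., 0., 0.],
    #         [0., 0., 1., 0., 0., 0.],
    #         [0., 0., 0., 1., 1., 1.]])

The matmul in the diff then uses this map to copy each token's encoder mean `mu_x` to every frame that token covers, producing the frame-level `mu_y`.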
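The final added line samples the initial latent for the reverse diffusion. Restating what that line computes, with T standing for `temperature`:

    z = \mu_y + \frac{\epsilon}{T}, \qquad \epsilon \sim \mathcal{N}(0, I)
    \quad\Longrightarrow\quad z \sim \mathcal{N}\!\Big(\mu_y, \tfrac{1}{T^{2}} I\Big)

so `temperature=1` matches the N(mu_y, I) terminal distribution named in the comment, and larger values concentrate the sample around the aligned encoder means.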
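For orientation, a hypothetical call against the new `__call__` signature. The four components passed to `__init__` are placeholders assumed to be built elsewhere (the tokenizer is assumed to return `(token_ids, lengths)` for a raw string, as the body above expects), and the argument values are illustrative only:

    import torch

    # Placeholder components; constructing them is outside the scope of this diff.
    pipeline = GradTTS(unet=unet, text_encoder=text_encoder,
                       noise_scheduler=noise_scheduler, tokenizer=tokenizer)

    generator = torch.manual_seed(0)
    output = pipeline(
        "Hello world.",
        num_inference_steps=50,   # illustrative value
        generator=generator,
        temperature=1.5,          # divides the terminal noise (see the note above)
        length_scale=1.0,         # scales the predicted token durations
        speaker_id=None,          # or an integer index for a multi-speaker model
    )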