Mirror of https://github.com/huggingface/diffusers.git

Add GradTTS

Author: Suraj Patil
Committed by: GitHub, 2022-06-16 18:28:13 +02:00

7 changed files with 559 additions and 3 deletions

src/diffusers/__init__.py

@@ -10,6 +10,6 @@ from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from .models.unet_grad_tts import UNetGradTTSModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion
-from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin
+from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
+from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
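
With the two import lines above updated, GradTTS and GradTTSScheduler become importable from the package root. A minimal sketch of what that enables (only the class names and the scheduler defaults come from this diff; the rest is illustrative):

# Hedged sketch: top-level exports added by this commit.
from diffusers import GradTTS, GradTTSScheduler

# Defaults match scheduling_grad_tts.py further down in this diff.
scheduler = GradTTSScheduler(timesteps=1000, beta_start=0.0001, beta_end=0.02)
print(len(scheduler))  # 1000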

src/diffusers/models/unet_grad_tts.py

@@ -4,7 +4,7 @@ import torch
 try:
-    from einops import rearrange, repeat
+    from einops import rearrange
 except:
     print("Einops is not installed")
     pass
@@ -144,9 +144,11 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        self.pe_scale = pe_scale

        if n_spks > 1:
            self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
            self.spk_mlp = torch.nn.Sequential(
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
            )

        self.time_pos_emb = SinusoidalPosEmb(dim)
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
@@ -189,6 +191,10 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, mask, mu, t, spk=None):
        if self.n_spks > 1:
            # Get speaker embedding
            spk = self.spk_emb(spk)

        if not isinstance(spk, type(None)):
            s = self.spk_mlp(spk)
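
For context, the spk_mlp added here projects the speaker embedding down to n_feats so that the resulting vector s can be stacked with the aligned text features and the noisy mel-spectrogram as an extra input channel. A hedged, standalone sketch of that broadcast-and-stack step (shapes reconstructed from the upstream Grad-TTS estimator; this exact code is not part of this hunk):

import torch

# Hedged sketch (assumption): how the projected speaker vector `s` is typically
# broadcast over time and stacked as an extra UNet input channel in Grad-TTS.
batch, n_feats, n_frames = 2, 80, 172
mu = torch.randn(batch, n_feats, n_frames)   # aligned text features
xt = torch.randn(batch, n_feats, n_frames)   # current noisy mel-spectrogram
s = torch.randn(batch, n_feats)              # output of spk_mlp (projected to n_feats)

s = s.unsqueeze(-1).repeat(1, 1, xt.shape[-1])   # (batch, n_feats, n_frames)
stacked = torch.stack([mu, xt, s], dim=1)        # (batch, 3, n_feats, n_frames)
print(stacked.shape)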

src/diffusers/pipelines/__init__.py

@@ -1,6 +1,7 @@
 from .pipeline_bddm import BDDM
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
+from .pipeline_grad_tts import GradTTS

 try:

src/diffusers/pipelines/grad_tts_utils.py

@@ -0,0 +1,435 @@
# tokenizer

import os
import re
from shutil import copyfile

import torch


try:
    from transformers import PreTrainedTokenizer
except:
    print("transformers is not installed")

try:
    from unidecode import unidecode
except:
    print("unidecode is not installed")
    pass

try:
    import inflect
except:
    print("inflect is not installed")
    pass
valid_symbols = [
"AA",
"AA0",
"AA1",
"AA2",
"AE",
"AE0",
"AE1",
"AE2",
"AH",
"AH0",
"AH1",
"AH2",
"AO",
"AO0",
"AO1",
"AO2",
"AW",
"AW0",
"AW1",
"AW2",
"AY",
"AY0",
"AY1",
"AY2",
"B",
"CH",
"D",
"DH",
"EH",
"EH0",
"EH1",
"EH2",
"ER",
"ER0",
"ER1",
"ER2",
"EY",
"EY0",
"EY1",
"EY2",
"F",
"G",
"HH",
"IH",
"IH0",
"IH1",
"IH2",
"IY",
"IY0",
"IY1",
"IY2",
"JH",
"K",
"L",
"M",
"N",
"NG",
"OW",
"OW0",
"OW1",
"OW2",
"OY",
"OY0",
"OY1",
"OY2",
"P",
"R",
"S",
"SH",
"T",
"TH",
"UH",
"UH0",
"UH1",
"UH2",
"UW",
"UW0",
"UW1",
"UW2",
"V",
"W",
"Y",
"Z",
"ZH",
]
_valid_symbol_set = set(valid_symbols)


def intersperse(lst, item):
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result


class CMUDict:
    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        return self._entries.get(word.upper())


_alt_re = re.compile(r"\([0-9]+\)")


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            # CMUdict separates the word from its pronunciation with two spaces
            parts = line.split("  ")
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict


def _get_pronunciation(s):
    parts = s.strip().split(" ")
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return " ".join(parts)
_whitespace_re = re.compile(r"\s+")

_abbreviations = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]
def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text


def expand_numbers(text):
    return normalize_numbers(text)


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)


def convert_to_ascii(text):
    return unidecode(text)


def basic_cleaners(text):
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text
_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")
def _remove_commas(m):
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace(".", " point ")


def _expand_dollars(m):
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return "%s %s" % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s" % (cents, cent_unit)
    else:
        return "zero dollars"


def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return "two thousand"
        elif num > 2000 and num < 2010:
            return "two thousand " + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            return _inflect.number_to_words(num // 100) + " hundred"
        else:
            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    else:
        return _inflect.number_to_words(num, andword="")


def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r"\1 pounds", text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text
""" from https://github.com/keithito/tacotron """
_pad = "_"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
# Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet = ["@" + s for s in valid_symbols]
# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet
_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}
_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
def get_arpabet(word, dictionary):
    word_arpabet = dictionary.lookup(word)
    if word_arpabet is not None:
        return "{" + word_arpabet[0] + "}"
    else:
        return word
def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
        dictionary: arpabet class with arpabet dictionary

    Returns:
        List of integers corresponding to the symbols in the text
    """
    sequence = []
    space = _symbols_to_sequence(" ")
    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            clean_text = _clean_text(text, cleaner_names)
            if dictionary is not None:
                clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")]
                for i in range(len(clean_text)):
                    t = clean_text[i]
                    if t.startswith("{"):
                        sequence += _arpabet_to_sequence(t[1:-1])
                    else:
                        sequence += _symbols_to_sequence(t)
                    sequence += space
            else:
                sequence += _symbols_to_sequence(clean_text)
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # remove trailing space
    if dictionary is not None:
        sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
    return sequence
def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    result = ""
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == "@":
                s = "{%s}" % s[1:]
            result += s
    return result.replace("}{", " ")


def _clean_text(text, cleaner_names):
    for cleaner in cleaner_names:
        text = cleaner(text)
    return text


def _symbols_to_sequence(symbols):
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
    return _symbols_to_sequence(["@" + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s != "_" and s != "~"
VOCAB_FILES_NAMES = {
    "dict_file": "dict_file.txt",
}


class GradTTSTokenizer(PreTrainedTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self, dict_file, **kwargs):
        super().__init__(**kwargs)
        self.cmu = CMUDict(dict_file)
        self.dict_file = dict_file

    def __call__(self, text):
        x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None]
        x_lengths = torch.LongTensor([x.shape[-1]])
        return x, x_lengths

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        dict_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
        )
        copyfile(self.dict_file, dict_file)
        return (dict_file,)
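
Taken together, the tokenizer cleans the text, optionally swaps words for their CMUdict ARPAbet pronunciations, maps everything to symbol IDs, and intersperses a blank ID (len(symbols)) between tokens, returning the ID tensor plus its length. A hedged usage sketch (the module path is inferred from the relative import in pipeline_grad_tts.py, and "cmu_dictionary" is a placeholder path to a CMUdict-formatted file; the commit itself ships no dictionary):

# Hedged sketch: paths and file names below are assumptions, not part of this commit.
from diffusers.pipelines.grad_tts_utils import GradTTSTokenizer

tokenizer = GradTTSTokenizer(dict_file="cmu_dictionary")
x, x_lengths = tokenizer("Hello world")
# x: LongTensor of shape (1, seq_len) with symbol IDs interleaved with the blank ID;
# x_lengths: LongTensor([seq_len]), as consumed by the GradTTS pipeline below.
print(x.shape, x_lengths)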

src/diffusers/pipelines/pipeline_grad_tts.py

@@ -5,9 +5,13 @@ import math
import torch
from torch import nn
import tqdm

from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin

from .grad_tts_utils import GradTTSTokenizer  # flake8: noqa


def sequence_mask(length, max_length=None):
    if max_length is None:
@@ -414,3 +418,63 @@ class TextEncoder(ModelMixin, ConfigMixin):
        logw = self.proj_w(x_dp, x_mask)

        return mu, logw, x_mask


class GradTTS(DiffusionPipeline):
    def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
        super().__init__()
        noise_scheduler = noise_scheduler.set_format("pt")
        self.register_modules(
            unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer
        )

    @torch.no_grad()
    def __call__(
        self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None
    ):
        if torch_device is None:
            torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.unet.to(torch_device)
        self.text_encoder.to(torch_device)

        x, x_lengths = self.tokenizer(text)
        x = x.to(torch_device)
        x_lengths = x_lengths.to(torch_device)

        if speaker_id is not None:
            speaker_id = torch.LongTensor([speaker_id]).to(torch_device)

        # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
        mu_x, logw, x_mask = self.text_encoder(x, x_lengths)

        w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w) * length_scale
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_max_length = int(y_lengths.max())
        y_max_length_ = fix_len_compatibility(y_max_length)

        # Using obtained durations `w` construct alignment map `attn`
        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

        # Align encoded text and get mu_y
        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)

        # Sample latent representation from terminal distribution N(mu_y, I)
        z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
        xt = z * y_mask

        h = 1.0 / num_inference_steps
        for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
            time = t.unsqueeze(-1).unsqueeze(-1)
            residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
            xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
            xt = xt * y_mask

        return xt[:, :, :y_max_length]
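
End to end, __call__ encodes the text into per-phoneme means mu_x and log-durations, expands them into frame-level means mu_y via the alignment map, samples the terminal latent z = mu_y + noise / temperature, and runs num_inference_steps scheduler steps to produce a mel-spectrogram. A hedged usage sketch (the checkpoint id is a placeholder; no weights are published by this commit, and a separate vocoder such as HiFi-GAN is needed to turn the output into audio):

import torch
from diffusers import GradTTS

# Hedged sketch: "someuser/grad-tts-libritts" is a hypothetical checkpoint id.
pipeline = GradTTS.from_pretrained("someuser/grad-tts-libritts")

mel = pipeline(
    "Hello world, this is a test.",
    num_inference_steps=50,  # number of reverse-diffusion steps
    temperature=1.3,         # terminal noise is divided by this; larger means less variation
    length_scale=0.91,       # scales predicted durations; values below 1 give faster speech
    speaker_id=15,           # only meaningful for multi-speaker checkpoints
)
print(mel.shape)  # (1, n_feats, n_frames) mel-spectrogram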

src/diffusers/schedulers/__init__.py

@@ -19,5 +19,6 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
+from .scheduling_grad_tts import GradTTSScheduler
 from .scheduling_pndm import PNDMScheduler
 from .scheduling_utils import SchedulerMixin

src/diffusers/schedulers/scheduling_grad_tts.py

@@ -0,0 +1,49 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin


class GradTTSScheduler(SchedulerMixin, ConfigMixin):
    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        tensor_format="np",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
        )
        self.timesteps = int(timesteps)
        self.set_format(tensor_format=tensor_format)

    def sample_noise(self, timestep):
        noise = self.beta_start + (self.beta_end - self.beta_start) * timestep
        return noise

    def step(self, xt, residual, mu, h, timestep):
        noise_t = self.sample_noise(timestep)
        dxt = 0.5 * (mu - xt - residual)
        dxt = dxt * noise_t * h
        xt = xt - dxt
        return xt

    def __len__(self):
        return self.timesteps
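
The scheduler implements the deterministic reverse update used by the GradTTS pipeline above: with a linear schedule beta_t = beta_start + (beta_end - beta_start) * t, each call to step computes xt <- xt - 0.5 * beta_t * h * (mu - xt - residual), where residual is the UNet output (the score estimate) and h = 1 / num_inference_steps. A hedged standalone sketch of a single step on toy tensors (shapes mirror the pipeline; values are random placeholders):

import torch
from diffusers import GradTTSScheduler

# Hedged sketch: one reverse step with illustrative shapes, not a real inference run.
scheduler = GradTTSScheduler(timesteps=1000, tensor_format="pt")

mu_y = torch.randn(1, 80, 172)           # aligned text features (terminal mean)
xt = mu_y + torch.randn_like(mu_y)       # current noisy mel-spectrogram
residual = torch.randn_like(mu_y)        # placeholder for the UNet's score estimate

num_inference_steps = 50
h = 1.0 / num_inference_steps
t = 1.0 - 0.5 * h                        # first (largest) continuous time, as in the pipeline loop
time = torch.tensor(t).reshape(1, 1, 1)  # matches the pipeline's t.unsqueeze(-1).unsqueeze(-1)

xt = scheduler.step(xt, residual, mu_y, h, time)
print(xt.shape)  # torch.Size([1, 80, 172])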