From ace07110c1696ccc47fa6eb8a36d153adb840e77 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 18:26:00 +0200 Subject: [PATCH] style --- src/diffusers/models/unet_grad_tts.py | 5 +- src/diffusers/pipelines/grad_tts_utils.py | 256 ++++++++++++------ src/diffusers/pipelines/pipeline_grad_tts.py | 44 +-- src/diffusers/schedulers/__init__.py | 2 +- .../schedulers/scheduling_grad_tts.py | 4 +- 5 files changed, 197 insertions(+), 114 deletions(-) diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 06fda8b473..6792177193 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -145,8 +145,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): if n_spks > 1: self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) - self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), - torch.nn.Linear(spk_emb_dim * 4, n_feats)) + self.spk_mlp = torch.nn.Sequential( + torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats) + ) self.time_pos_emb = SinusoidalPosEmb(dim) self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)) diff --git a/src/diffusers/pipelines/grad_tts_utils.py b/src/diffusers/pipelines/grad_tts_utils.py index 0e3eeb35ad..a96fbc2d16 100644 --- a/src/diffusers/pipelines/grad_tts_utils.py +++ b/src/diffusers/pipelines/grad_tts_utils.py @@ -1,11 +1,12 @@ # tokenizer -import re import os +import re from shutil import copyfile import torch + try: from transformers import PreTrainedTokenizer except: @@ -25,17 +26,95 @@ except: valid_symbols = [ - 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', - 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', - 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', - 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', - 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', - 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', - 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' + "AA", + "AA0", + "AA1", + "AA2", + "AE", + "AE0", + "AE1", + "AE2", + "AH", + "AH0", + "AH1", + "AH2", + "AO", + "AO0", + "AO1", + "AO2", + "AW", + "AW0", + "AW1", + "AW2", + "AY", + "AY0", + "AY1", + "AY2", + "B", + "CH", + "D", + "DH", + "EH", + "EH0", + "EH1", + "EH2", + "ER", + "ER0", + "ER1", + "ER2", + "EY", + "EY0", + "EY1", + "EY2", + "F", + "G", + "HH", + "IH", + "IH0", + "IH1", + "IH2", + "IY", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "NG", + "OW", + "OW0", + "OW1", + "OW2", + "OY", + "OY0", + "OY1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH", + "UH0", + "UH1", + "UH2", + "UW", + "UW0", + "UW1", + "UW2", + "V", + "W", + "Y", + "Z", + "ZH", ] _valid_symbol_set = set(valid_symbols) + def intersperse(lst, item): # Adds blank symbol result = [item] * (len(lst) * 2 + 1) @@ -46,7 +125,7 @@ def intersperse(lst, item): class CMUDict: def __init__(self, file_or_path, keep_ambiguous=True): if isinstance(file_or_path, str): - with open(file_or_path, encoding='latin-1') as f: + with open(file_or_path, encoding="latin-1") as f: entries = _parse_cmudict(f) else: entries = _parse_cmudict(file_or_path) @@ -61,15 +140,15 @@ class CMUDict: return self._entries.get(word.upper()) -_alt_re = re.compile(r'\([0-9]+\)') +_alt_re = re.compile(r"\([0-9]+\)") def _parse_cmudict(file): cmudict = {} for line in file: - if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): - parts = line.split(' ') - word = re.sub(_alt_re, '', parts[0]) + if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) pronunciation = _get_pronunciation(parts[1]) if pronunciation: if word in cmudict: @@ -80,36 +159,38 @@ def _parse_cmudict(file): def _get_pronunciation(s): - parts = s.strip().split(' ') + parts = s.strip().split(" ") for part in parts: if part not in _valid_symbol_set: return None - return ' '.join(parts) + return " ".join(parts) +_whitespace_re = re.compile(r"\s+") -_whitespace_re = re.compile(r'\s+') - -_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), -]] +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] def expand_abbreviations(text): @@ -127,7 +208,7 @@ def lowercase(text): def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text) + return re.sub(_whitespace_re, " ", text) def convert_to_ascii(text): @@ -156,46 +237,42 @@ def english_cleaners(text): return text - - - - _inflect = inflect.engine() -_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') -_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') -_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') -_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') -_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') -_number_re = re.compile(r'[0-9]+') +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") def _remove_commas(m): - return m.group(1).replace(',', '') + return m.group(1).replace(",", "") def _expand_decimal_point(m): - return m.group(1).replace('.', ' point ') + return m.group(1).replace(".", " point ") def _expand_dollars(m): match = m.group(1) - parts = match.split('.') + parts = match.split(".") if len(parts) > 2: - return match + ' dollars' + return match + " dollars" dollars = int(parts[0]) if parts[0] else 0 cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 if dollars and cents: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) elif dollars: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - return '%s %s' % (dollars, dollar_unit) + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) elif cents: - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s' % (cents, cent_unit) + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) else: - return 'zero dollars' + return "zero dollars" def _expand_ordinal(m): @@ -206,37 +283,37 @@ def _expand_number(m): num = int(m.group(0)) if num > 1000 and num < 3000: if num == 2000: - return 'two thousand' + return "two thousand" elif num > 2000 and num < 2010: - return 'two thousand ' + _inflect.number_to_words(num % 100) + return "two thousand " + _inflect.number_to_words(num % 100) elif num % 100 == 0: - return _inflect.number_to_words(num // 100) + ' hundred' + return _inflect.number_to_words(num // 100) + " hundred" else: - return _inflect.number_to_words(num, andword='', zero='oh', - group=2).replace(', ', ' ') + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") else: - return _inflect.number_to_words(num, andword='') + return _inflect.number_to_words(num, andword="") def normalize_numbers(text): text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_pounds_re, r"\1 pounds", text) text = re.sub(_dollars_re, _expand_dollars, text) text = re.sub(_decimal_number_re, _expand_decimal_point, text) text = re.sub(_ordinal_re, _expand_ordinal, text) text = re.sub(_number_re, _expand_number, text) return text + """ from https://github.com/keithito/tacotron """ -_pad = '_' -_punctuation = '!\'(),.:;? ' -_special = '-' -_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +_pad = "_" +_punctuation = "!'(),.:;? " +_special = "-" +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" # Prepend "@" to ARPAbet symbols to ensure uniqueness: -_arpabet = ['@' + s for s in valid_symbols] +_arpabet = ["@" + s for s in valid_symbols] # Export all symbols: symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet @@ -245,7 +322,7 @@ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpab _symbol_to_id = {s: i for i, s in enumerate(symbols)} _id_to_symbol = {i: s for i, s in enumerate(symbols)} -_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') +_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") def get_arpabet(word, dictionary): @@ -257,7 +334,7 @@ def get_arpabet(word, dictionary): def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): - '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. The text can optionally have ARPAbet sequences enclosed in curly braces embedded in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." @@ -269,9 +346,9 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): Returns: List of integers corresponding to the symbols in the text - ''' + """ sequence = [] - space = _symbols_to_sequence(' ') + space = _symbols_to_sequence(" ") # Check for curly braces and treat their contents as ARPAbet: while len(text): m = _curly_re.match(text) @@ -292,7 +369,7 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) sequence += _arpabet_to_sequence(m.group(2)) text = m.group(3) - + # remove trailing space if dictionary is not None: sequence = sequence[:-1] if sequence[-1] == space[0] else sequence @@ -300,16 +377,16 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): def sequence_to_text(sequence): - '''Converts a sequence of IDs back to a string''' - result = '' + """Converts a sequence of IDs back to a string""" + result = "" for symbol_id in sequence: if symbol_id in _id_to_symbol: s = _id_to_symbol[symbol_id] # Enclose ARPAbet back in curly braces: - if len(s) > 1 and s[0] == '@': - s = '{%s}' % s[1:] + if len(s) > 1 and s[0] == "@": + s = "{%s}" % s[1:] result += s - return result.replace('}{', ' ') + return result.replace("}{", " ") def _clean_text(text, cleaner_names): @@ -323,17 +400,18 @@ def _symbols_to_sequence(symbols): def _arpabet_to_sequence(text): - return _symbols_to_sequence(['@' + s for s in text.split()]) + return _symbols_to_sequence(["@" + s for s in text.split()]) def _should_keep_symbol(s): - return s in _symbol_to_id and s != '_' and s != '~' + return s in _symbol_to_id and s != "_" and s != "~" VOCAB_FILES_NAMES = { "dict_file": "dict_file.txt", } + class GradTTSTokenizer(PreTrainedTokenizer): vocab_files_names = VOCAB_FILES_NAMES @@ -341,17 +419,17 @@ class GradTTSTokenizer(PreTrainedTokenizer): super().__init__(**kwargs) self.cmu = CMUDict(dict_file) self.dict_file = dict_file - + def __call__(self, text): x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None] x_lengths = torch.LongTensor([x.shape[-1]]) return x, x_lengths - - def save_vocabulary(self, save_directory: str, filename_prefix = None): + + def save_vocabulary(self, save_directory: str, filename_prefix=None): dict_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"] ) copyfile(self.dict_file, dict_file) - - return (dict_file, ) + + return (dict_file,) diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index a18aa44714..246661be38 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -4,13 +4,13 @@ import math import torch from torch import nn -import tqdm +import tqdm +from diffusers import DiffusionPipeline from diffusers.configuration_utils import ConfigMixin from diffusers.modeling_utils import ModelMixin -from diffusers import DiffusionPipeline -from .grad_tts_utils import GradTTSTokenizer # flake8: noqa +from .grad_tts_utils import GradTTSTokenizer # flake8: noqa def sequence_mask(length, max_length=None): @@ -382,7 +382,7 @@ class TextEncoder(ModelMixin, ConfigMixin): self.window_size = window_size self.spk_emb_dim = spk_emb_dim self.n_spks = n_spks - + self.emb = torch.nn.Embedding(n_vocab, n_channels) torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5) @@ -403,7 +403,7 @@ class TextEncoder(ModelMixin, ConfigMixin): n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout ) - def forward(self, x, x_lengths, spk=None): + def forward(self, x, x_lengths, spk=None): x = self.emb(x) * math.sqrt(self.n_channels) x = torch.transpose(x, 1, -1) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) @@ -424,26 +424,30 @@ class GradTTS(DiffusionPipeline): def __init__(self, unet, text_encoder, noise_scheduler, tokenizer): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") - self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer) - + self.register_modules( + unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer + ) + @torch.no_grad() - def __call__(self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None): + def __call__( + self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None + ): if torch_device is None: torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - + self.unet.to(torch_device) self.text_encoder.to(torch_device) - + x, x_lengths = self.tokenizer(text) x = x.to(torch_device) x_lengths = x_lengths.to(torch_device) - + if speaker_id is not None: - speaker_id= torch.LongTensor([speaker_id]).to(torch_device) - + speaker_id = torch.LongTensor([speaker_id]).to(torch_device) + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` mu_x, logw, x_mask = self.text_encoder(x, x_lengths) - + w = torch.exp(logw) * x_mask w_ceil = torch.ceil(w) * length_scale y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() @@ -461,16 +465,16 @@ class GradTTS(DiffusionPipeline): # Sample latent representation from terminal distribution N(mu_y, I) z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature - + xt = z * y_mask h = 1.0 / num_inference_steps for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps): - t = (1.0 - (t + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) + t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) time = t.unsqueeze(-1).unsqueeze(-1) - + residual = self.unet(xt, y_mask, mu_y, t, speaker_id) - + xt = self.noise_scheduler.step(xt, residual, mu_y, h, time) xt = xt * y_mask - - return xt[:, :, :y_max_length] \ No newline at end of file + + return xt[:, :, :y_max_length] diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 9e1cd3edc8..47e5f6a1db 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -19,6 +19,6 @@ from .classifier_free_guidance import ClassifierFreeGuidanceScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler -from .scheduling_pndm import PNDMScheduler from .scheduling_grad_tts import GradTTSScheduler +from .scheduling_pndm import PNDMScheduler from .scheduling_utils import SchedulerMixin diff --git a/src/diffusers/schedulers/scheduling_grad_tts.py b/src/diffusers/schedulers/scheduling_grad_tts.py index 11a557a3e1..2c154c87d6 100644 --- a/src/diffusers/schedulers/scheduling_grad_tts.py +++ b/src/diffusers/schedulers/scheduling_grad_tts.py @@ -36,11 +36,11 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin): self.timesteps = int(timesteps) self.set_format(tensor_format=tensor_format) - + def sample_noise(self, timestep): noise = self.beta_start + (self.beta_end - self.beta_start) * timestep return noise - + def step(self, xt, residual, mu, h, timestep): noise_t = self.sample_noise(timestep) dxt = 0.5 * (mu - xt - residual)