From 986cc9b2f4cdabbb779c1991887d1d4f8e5880c5 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 14:08:41 +0200 Subject: [PATCH 01/13] add tokenizer --- src/diffusers/pipelines/grad_tts_utils.py | 341 ++++++++++++++++++++++ 1 file changed, 341 insertions(+) create mode 100644 src/diffusers/pipelines/grad_tts_utils.py diff --git a/src/diffusers/pipelines/grad_tts_utils.py b/src/diffusers/pipelines/grad_tts_utils.py new file mode 100644 index 0000000000..bc7c883d26 --- /dev/null +++ b/src/diffusers/pipelines/grad_tts_utils.py @@ -0,0 +1,341 @@ +# tokenizer + +import re + +import torch +from transformers import PreTrainedTokenizer + +try: + from unidecode import unidecode +except: + print("unidecode is not installed") + pass + +try: + import inflect +except: + print("inflect is not installed") + pass + + +valid_symbols = [ + 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', + 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', + 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', + 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', + 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', + 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', + 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' +] + +_valid_symbol_set = set(valid_symbols) + +def intersperse(lst, item): + # Adds blank symbol + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +class CMUDict: + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding='latin-1') as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} + self._entries = entries + + def __len__(self): + return len(self._entries) + + def lookup(self, word): + return self._entries.get(word.upper()) + + +_alt_re = re.compile(r'\([0-9]+\)') + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): + parts = line.split(' ') + word = re.sub(_alt_re, '', parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(' ') + for part in parts: + if part not in _valid_symbol_set: + return None + return ' '.join(parts) + + + +_whitespace_re = re.compile(r'\s+') + +_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text + + + + + + +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words(num, andword='', zero='oh', + group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text + +""" from https://github.com/keithito/tacotron """ + + +_pad = '_' +_punctuation = '!\'(),.:;? ' +_special = '-' +_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' + +# Prepend "@" to ARPAbet symbols to ensure uniqueness: +_arpabet = ['@' + s for s in valid_symbols] + +# Export all symbols: +symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet + + +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + +_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') + + +def get_arpabet(word, dictionary): + word_arpabet = dictionary.lookup(word) + if word_arpabet is not None: + return "{" + word_arpabet[0] + "}" + else: + return word + + +def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." + + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + dictionary: arpabet class with arpabet dictionary + + Returns: + List of integers corresponding to the symbols in the text + ''' + sequence = [] + space = _symbols_to_sequence(' ') + # Check for curly braces and treat their contents as ARPAbet: + while len(text): + m = _curly_re.match(text) + if not m: + clean_text = _clean_text(text, cleaner_names) + if dictionary is not None: + clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")] + for i in range(len(clean_text)): + t = clean_text[i] + if t.startswith("{"): + sequence += _arpabet_to_sequence(t[1:-1]) + else: + sequence += _symbols_to_sequence(t) + sequence += space + else: + sequence += _symbols_to_sequence(clean_text) + break + sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) + sequence += _arpabet_to_sequence(m.group(2)) + text = m.group(3) + + # remove trailing space + if dictionary is not None: + sequence = sequence[:-1] if sequence[-1] == space[0] else sequence + return sequence + + +def sequence_to_text(sequence): + '''Converts a sequence of IDs back to a string''' + result = '' + for symbol_id in sequence: + if symbol_id in _id_to_symbol: + s = _id_to_symbol[symbol_id] + # Enclose ARPAbet back in curly braces: + if len(s) > 1 and s[0] == '@': + s = '{%s}' % s[1:] + result += s + return result.replace('}{', ' ') + + +def _clean_text(text, cleaner_names): + for cleaner in cleaner_names: + text = cleaner(text) + return text + + +def _symbols_to_sequence(symbols): + return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] + + +def _arpabet_to_sequence(text): + return _symbols_to_sequence(['@' + s for s in text.split()]) + + +def _should_keep_symbol(s): + return s in _symbol_to_id and s != '_' and s != '~' + + +VOCAB_FILES_NAMES = { + "dict_file": "merges.txt", +} + +class GradTTSTokenizer(PreTrainedTokenizer): + vocab_files_names = VOCAB_FILES_NAMES + + def __init__(self, dict_file, **kwargs): + super().__init__(**kwargs) + self.cmu = CMUDict(dict_file) + + def __call__(self, text): + x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None] + x_lengths = torch.LongTensor([x.shape[-1]]) + return x.shape, x_lengths From 7b55d334d592037b1bb9b3389ce26828a3073407 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 14:08:53 +0200 Subject: [PATCH 02/13] being pipeline --- src/diffusers/pipelines/pipeline_grad_tts.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 2d8f694638..c32d77e762 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -7,6 +7,9 @@ from torch import nn from diffusers.configuration_utils import ConfigMixin from diffusers.modeling_utils import ModelMixin +from diffusers import DiffusionPipeline + +from .grad_tts_utils import text_to_sequence def sequence_mask(length, max_length=None): @@ -383,3 +386,18 @@ class TextEncoder(ModelMixin, ConfigMixin): logw = self.proj_w(x_dp, x_mask) return mu, logw, x_mask + + +class GradTTS(DiffusionPipeline): + def __init__(self, unet, noise_scheduler): + super().__init__() + noise_scheduler = noise_scheduler.set_format("pt") + self.register_modules(diffwave=unet, noise_scheduler=noise_scheduler) + + @torch.no_grad() + def __call__(self, text, speaker_id, num_inference_steps, generator, torch_device=None): + if torch_device is None: + torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + pass + + \ No newline at end of file From 3f2d46a14ec57b98b90ee8d9447ac23571d53922 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 16:47:04 +0200 Subject: [PATCH 03/13] fix tokenizer --- src/diffusers/pipelines/grad_tts_utils.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/grad_tts_utils.py b/src/diffusers/pipelines/grad_tts_utils.py index bc7c883d26..e41501c458 100644 --- a/src/diffusers/pipelines/grad_tts_utils.py +++ b/src/diffusers/pipelines/grad_tts_utils.py @@ -1,6 +1,8 @@ # tokenizer import re +import os +from shutil import copyfile import torch from transformers import PreTrainedTokenizer @@ -325,17 +327,27 @@ def _should_keep_symbol(s): VOCAB_FILES_NAMES = { - "dict_file": "merges.txt", + "dict_file": "dict_file.txt", } class GradTTSTokenizer(PreTrainedTokenizer): vocab_files_names = VOCAB_FILES_NAMES - + def __init__(self, dict_file, **kwargs): super().__init__(**kwargs) self.cmu = CMUDict(dict_file) + self.dict_file = dict_file def __call__(self, text): x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None] x_lengths = torch.LongTensor([x.shape[-1]]) - return x.shape, x_lengths + return x, x_lengths + + def save_vocabulary(self, save_directory: str, filename_prefix = None): + dict_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"] + ) + + copyfile(self.dict_file, dict_file) + + return (dict_file, ) From 71ecc7aed87718d8383acae71c29aab2008a26a2 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 16:48:00 +0200 Subject: [PATCH 04/13] add speaker emb in unet --- src/diffusers/models/unet_grad_tts.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index de2d6aa2f1..7bd6414b05 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -154,6 +154,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): self.pe_scale = pe_scale if n_spks > 1: + self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)) self.time_pos_emb = SinusoidalPosEmb(dim) @@ -189,6 +190,10 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): self.final_conv = torch.nn.Conv2d(dim, 1, 1) def forward(self, x, mask, mu, t, spk=None): + if self.n_spks > 1: + # Get speaker embedding + spk = self.spk_emb(spk) + if not isinstance(spk, type(None)): s = self.spk_mlp(spk) From 2d8d82f93e95b238ef2f5dc617ebc9f0786f9efb Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 16:48:23 +0200 Subject: [PATCH 05/13] update grad tts pipeline --- src/diffusers/pipelines/pipeline_grad_tts.py | 42 ++++++++++++++++---- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index c32d77e762..d5a23d9677 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -357,7 +357,7 @@ class TextEncoder(ModelMixin, ConfigMixin): self.window_size = window_size self.spk_emb_dim = spk_emb_dim self.n_spks = n_spks - + self.emb = torch.nn.Embedding(n_vocab, n_channels) torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5) @@ -371,7 +371,7 @@ class TextEncoder(ModelMixin, ConfigMixin): self.proj_w = DurationPredictor(n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout) - def forward(self, x, x_lengths, spk=None): + def forward(self, x, x_lengths, spk=None): x = self.emb(x) * math.sqrt(self.n_channels) x = torch.transpose(x, 1, -1) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) @@ -385,19 +385,47 @@ class TextEncoder(ModelMixin, ConfigMixin): x_dp = torch.detach(x) logw = self.proj_w(x_dp, x_mask) - return mu, logw, x_mask + return mu, logw, x_mask, spk class GradTTS(DiffusionPipeline): - def __init__(self, unet, noise_scheduler): + def __init__(self, unet, text_encoder, noise_scheduler, tokenizer): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") - self.register_modules(diffwave=unet, noise_scheduler=noise_scheduler) + self.register_modules(diffwave=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer) @torch.no_grad() - def __call__(self, text, speaker_id, num_inference_steps, generator, torch_device=None): + def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None): if torch_device is None: torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - pass + + x, x_lengths = self.tokenizer(text) + + if speaker_id is not None: + speaker_id= torch.longTensor([speaker_id]) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.text_encoder(x, x_lengths) + + w = torch.exp(logw) * x_mask + w_ceil = torch.ceil(w) * length_scale + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = int(y_lengths.max()) + y_max_length_ = fix_len_compatibility(y_max_length) + + # Using obtained durations `w` construct alignment map `attn` + y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + + # Align encoded text and get mu_y + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + encoder_outputs = mu_y[:, :, :y_max_length] + + # Sample latent representation from terminal distribution N(mu_y, I) + z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature + + \ No newline at end of file From cc45831ec6dfdc7a115c06a2336120d0448937e7 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 17:10:36 +0200 Subject: [PATCH 06/13] add GradTTSScheduler --- src/diffusers/__init__.py | 2 +- src/diffusers/schedulers/__init__.py | 1 + .../schedulers/scheduling_grad_tts.py | 52 +++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/schedulers/scheduling_grad_tts.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 2f4d2ab6dc..7e04aa0ac8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -11,5 +11,5 @@ from .models.unet_ldm import UNetLDMModel from .models.unet_grad_tts import UNetGradTTSModel from .pipeline_utils import DiffusionPipeline from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM -from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler +from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler, GradTTSScheduler from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 5e9dcaf64e..9e1cd3edc8 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -20,4 +20,5 @@ from .classifier_free_guidance import ClassifierFreeGuidanceScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler from .scheduling_pndm import PNDMScheduler +from .scheduling_grad_tts import GradTTSScheduler from .scheduling_utils import SchedulerMixin diff --git a/src/diffusers/schedulers/scheduling_grad_tts.py b/src/diffusers/schedulers/scheduling_grad_tts.py new file mode 100644 index 0000000000..11a557a3e1 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_grad_tts.py @@ -0,0 +1,52 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math + +import numpy as np + +from ..configuration_utils import ConfigMixin +from .scheduling_utils import SchedulerMixin + + +class GradTTSScheduler(SchedulerMixin, ConfigMixin): + def __init__( + self, + timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + tensor_format="np", + ): + super().__init__() + self.register( + timesteps=timesteps, + beta_start=beta_start, + beta_end=beta_end, + ) + self.timesteps = int(timesteps) + + self.set_format(tensor_format=tensor_format) + + def sample_noise(self, timestep): + noise = self.beta_start + (self.beta_end - self.beta_start) * timestep + return noise + + def step(self, xt, residual, mu, h, timestep): + noise_t = self.sample_noise(timestep) + dxt = 0.5 * (mu - xt - residual) + dxt = dxt * noise_t * h + xt = xt - dxt + return xt + + def __len__(self): + return self.timesteps From cdf26c55f58e6dc552be7b2e44584630e7eab2f5 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 18:07:59 +0200 Subject: [PATCH 07/13] remove unused import --- src/diffusers/models/unet_grad_tts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 7bd6414b05..72009813ad 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -3,7 +3,7 @@ import math import torch try: - from einops import rearrange, repeat + from einops import rearrange except: print("Einops is not installed") pass From 800739361421dc559230faaec25ab6363f9ebf26 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 18:08:21 +0200 Subject: [PATCH 08/13] wrap transformers import with try/catch --- src/diffusers/pipelines/grad_tts_utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/grad_tts_utils.py b/src/diffusers/pipelines/grad_tts_utils.py index e41501c458..0e3eeb35ad 100644 --- a/src/diffusers/pipelines/grad_tts_utils.py +++ b/src/diffusers/pipelines/grad_tts_utils.py @@ -5,7 +5,11 @@ import os from shutil import copyfile import torch -from transformers import PreTrainedTokenizer + +try: + from transformers import PreTrainedTokenizer +except: + print("transformers is not installed") try: from unidecode import unidecode From 1d2551d716b3fd9f40e528f50902f7cb17dc4617 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 18:08:33 +0200 Subject: [PATCH 09/13] finish GradTTS pipeline --- src/diffusers/pipelines/pipeline_grad_tts.py | 27 +++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index d5a23d9677..9f6fdbba59 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -4,12 +4,13 @@ import math import torch from torch import nn +import tqdm from diffusers.configuration_utils import ConfigMixin from diffusers.modeling_utils import ModelMixin from diffusers import DiffusionPipeline -from .grad_tts_utils import text_to_sequence +from .grad_tts_utils import GradTTSTokenizer # flake8: noqa def sequence_mask(length, max_length=None): @@ -385,24 +386,29 @@ class TextEncoder(ModelMixin, ConfigMixin): x_dp = torch.detach(x) logw = self.proj_w(x_dp, x_mask) - return mu, logw, x_mask, spk + return mu, logw, x_mask class GradTTS(DiffusionPipeline): def __init__(self, unet, text_encoder, noise_scheduler, tokenizer): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") - self.register_modules(diffwave=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer) + self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer) @torch.no_grad() def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None): if torch_device is None: torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.unet.to(torch_device) + self.text_encoder.to(torch_device) + x, x_lengths = self.tokenizer(text) + x = x.to(torch_device) + x_lengths = x_lengths.to(torch_device) if speaker_id is not None: - speaker_id= torch.longTensor([speaker_id]) + speaker_id= torch.LongTensor([speaker_id]).to(torch_device) # Get encoder_outputs `mu_x` and log-scaled token durations `logw` mu_x, logw, x_mask = self.text_encoder(x, x_lengths) @@ -426,6 +432,15 @@ class GradTTS(DiffusionPipeline): # Sample latent representation from terminal distribution N(mu_y, I) z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature + xt = z * y_mask + h = 1.0 / num_inference_steps + for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps): + t = (1.0 - (t + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) + time = t.unsqueeze(-1).unsqueeze(-1) + + residual = self.unet(xt, y_mask, mu_y, t, speaker_id) + + xt = self.noise_scheduler.step(xt, residual, mu_y, h, time) + xt = xt * y_mask - - \ No newline at end of file + return xt[:, :, :y_max_length] \ No newline at end of file From e26782759ca731c40fb926c04f53764d487713a2 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 18:14:01 +0200 Subject: [PATCH 10/13] add GradTTS in init --- src/diffusers/__init__.py | 2 +- src/diffusers/pipelines/__init__.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7e04aa0ac8..4e0b9b12c2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -10,6 +10,6 @@ from .models.unet_glide import GLIDEUNetModel, GLIDESuperResUNetModel, GLIDEText from .models.unet_ldm import UNetLDMModel from .models.unet_grad_tts import UNetGradTTSModel from .pipeline_utils import DiffusionPipeline -from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM +from .pipelines import DDIM, DDPM, GLIDE, LatentDiffusion, PNDM, BDDM, GradTTS from .schedulers import DDIMScheduler, DDPMScheduler, SchedulerMixin, PNDMScheduler, GradTTSScheduler from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index e0d2bf2e30..176335b85a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -4,3 +4,4 @@ from .pipeline_pndm import PNDM from .pipeline_glide import GLIDE from .pipeline_latent_diffusion import LatentDiffusion from .pipeline_bddm import BDDM +from .pipeline_grad_tts import GradTTS \ No newline at end of file From 5a3467e623c4d74c1f9f2a1239e5a6e0d91042fc Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 18:17:45 +0200 Subject: [PATCH 11/13] add default params for GradTTS --- src/diffusers/pipelines/pipeline_grad_tts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 9f6fdbba59..cd4f92dd90 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -396,7 +396,7 @@ class GradTTS(DiffusionPipeline): self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer) @torch.no_grad() - def __call__(self, text, num_inference_steps, generator, temperature, length_scale, speaker_id=None, torch_device=None): + def __call__(self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None): if torch_device is None: torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -427,7 +427,6 @@ class GradTTS(DiffusionPipeline): # Align encoded text and get mu_y mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) mu_y = mu_y.transpose(1, 2) - encoder_outputs = mu_y[:, :, :y_max_length] # Sample latent representation from terminal distribution N(mu_y, I) z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature From ace07110c1696ccc47fa6eb8a36d153adb840e77 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 18:26:00 +0200 Subject: [PATCH 12/13] style --- src/diffusers/models/unet_grad_tts.py | 5 +- src/diffusers/pipelines/grad_tts_utils.py | 256 ++++++++++++------ src/diffusers/pipelines/pipeline_grad_tts.py | 44 +-- src/diffusers/schedulers/__init__.py | 2 +- .../schedulers/scheduling_grad_tts.py | 4 +- 5 files changed, 197 insertions(+), 114 deletions(-) diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 06fda8b473..6792177193 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -145,8 +145,9 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): if n_spks > 1: self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) - self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), - torch.nn.Linear(spk_emb_dim * 4, n_feats)) + self.spk_mlp = torch.nn.Sequential( + torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats) + ) self.time_pos_emb = SinusoidalPosEmb(dim) self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)) diff --git a/src/diffusers/pipelines/grad_tts_utils.py b/src/diffusers/pipelines/grad_tts_utils.py index 0e3eeb35ad..a96fbc2d16 100644 --- a/src/diffusers/pipelines/grad_tts_utils.py +++ b/src/diffusers/pipelines/grad_tts_utils.py @@ -1,11 +1,12 @@ # tokenizer -import re import os +import re from shutil import copyfile import torch + try: from transformers import PreTrainedTokenizer except: @@ -25,17 +26,95 @@ except: valid_symbols = [ - 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', - 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', - 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', - 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', - 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', - 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', - 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' + "AA", + "AA0", + "AA1", + "AA2", + "AE", + "AE0", + "AE1", + "AE2", + "AH", + "AH0", + "AH1", + "AH2", + "AO", + "AO0", + "AO1", + "AO2", + "AW", + "AW0", + "AW1", + "AW2", + "AY", + "AY0", + "AY1", + "AY2", + "B", + "CH", + "D", + "DH", + "EH", + "EH0", + "EH1", + "EH2", + "ER", + "ER0", + "ER1", + "ER2", + "EY", + "EY0", + "EY1", + "EY2", + "F", + "G", + "HH", + "IH", + "IH0", + "IH1", + "IH2", + "IY", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "NG", + "OW", + "OW0", + "OW1", + "OW2", + "OY", + "OY0", + "OY1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH", + "UH0", + "UH1", + "UH2", + "UW", + "UW0", + "UW1", + "UW2", + "V", + "W", + "Y", + "Z", + "ZH", ] _valid_symbol_set = set(valid_symbols) + def intersperse(lst, item): # Adds blank symbol result = [item] * (len(lst) * 2 + 1) @@ -46,7 +125,7 @@ def intersperse(lst, item): class CMUDict: def __init__(self, file_or_path, keep_ambiguous=True): if isinstance(file_or_path, str): - with open(file_or_path, encoding='latin-1') as f: + with open(file_or_path, encoding="latin-1") as f: entries = _parse_cmudict(f) else: entries = _parse_cmudict(file_or_path) @@ -61,15 +140,15 @@ class CMUDict: return self._entries.get(word.upper()) -_alt_re = re.compile(r'\([0-9]+\)') +_alt_re = re.compile(r"\([0-9]+\)") def _parse_cmudict(file): cmudict = {} for line in file: - if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): - parts = line.split(' ') - word = re.sub(_alt_re, '', parts[0]) + if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) pronunciation = _get_pronunciation(parts[1]) if pronunciation: if word in cmudict: @@ -80,36 +159,38 @@ def _parse_cmudict(file): def _get_pronunciation(s): - parts = s.strip().split(' ') + parts = s.strip().split(" ") for part in parts: if part not in _valid_symbol_set: return None - return ' '.join(parts) + return " ".join(parts) +_whitespace_re = re.compile(r"\s+") -_whitespace_re = re.compile(r'\s+') - -_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), -]] +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] def expand_abbreviations(text): @@ -127,7 +208,7 @@ def lowercase(text): def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text) + return re.sub(_whitespace_re, " ", text) def convert_to_ascii(text): @@ -156,46 +237,42 @@ def english_cleaners(text): return text - - - - _inflect = inflect.engine() -_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') -_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') -_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') -_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') -_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') -_number_re = re.compile(r'[0-9]+') +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") def _remove_commas(m): - return m.group(1).replace(',', '') + return m.group(1).replace(",", "") def _expand_decimal_point(m): - return m.group(1).replace('.', ' point ') + return m.group(1).replace(".", " point ") def _expand_dollars(m): match = m.group(1) - parts = match.split('.') + parts = match.split(".") if len(parts) > 2: - return match + ' dollars' + return match + " dollars" dollars = int(parts[0]) if parts[0] else 0 cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 if dollars and cents: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) elif dollars: - dollar_unit = 'dollar' if dollars == 1 else 'dollars' - return '%s %s' % (dollars, dollar_unit) + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) elif cents: - cent_unit = 'cent' if cents == 1 else 'cents' - return '%s %s' % (cents, cent_unit) + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) else: - return 'zero dollars' + return "zero dollars" def _expand_ordinal(m): @@ -206,37 +283,37 @@ def _expand_number(m): num = int(m.group(0)) if num > 1000 and num < 3000: if num == 2000: - return 'two thousand' + return "two thousand" elif num > 2000 and num < 2010: - return 'two thousand ' + _inflect.number_to_words(num % 100) + return "two thousand " + _inflect.number_to_words(num % 100) elif num % 100 == 0: - return _inflect.number_to_words(num // 100) + ' hundred' + return _inflect.number_to_words(num // 100) + " hundred" else: - return _inflect.number_to_words(num, andword='', zero='oh', - group=2).replace(', ', ' ') + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") else: - return _inflect.number_to_words(num, andword='') + return _inflect.number_to_words(num, andword="") def normalize_numbers(text): text = re.sub(_comma_number_re, _remove_commas, text) - text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_pounds_re, r"\1 pounds", text) text = re.sub(_dollars_re, _expand_dollars, text) text = re.sub(_decimal_number_re, _expand_decimal_point, text) text = re.sub(_ordinal_re, _expand_ordinal, text) text = re.sub(_number_re, _expand_number, text) return text + """ from https://github.com/keithito/tacotron """ -_pad = '_' -_punctuation = '!\'(),.:;? ' -_special = '-' -_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' +_pad = "_" +_punctuation = "!'(),.:;? " +_special = "-" +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" # Prepend "@" to ARPAbet symbols to ensure uniqueness: -_arpabet = ['@' + s for s in valid_symbols] +_arpabet = ["@" + s for s in valid_symbols] # Export all symbols: symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet @@ -245,7 +322,7 @@ symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpab _symbol_to_id = {s: i for i, s in enumerate(symbols)} _id_to_symbol = {i: s for i, s in enumerate(symbols)} -_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') +_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") def get_arpabet(word, dictionary): @@ -257,7 +334,7 @@ def get_arpabet(word, dictionary): def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): - '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. The text can optionally have ARPAbet sequences enclosed in curly braces embedded in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." @@ -269,9 +346,9 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): Returns: List of integers corresponding to the symbols in the text - ''' + """ sequence = [] - space = _symbols_to_sequence(' ') + space = _symbols_to_sequence(" ") # Check for curly braces and treat their contents as ARPAbet: while len(text): m = _curly_re.match(text) @@ -292,7 +369,7 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) sequence += _arpabet_to_sequence(m.group(2)) text = m.group(3) - + # remove trailing space if dictionary is not None: sequence = sequence[:-1] if sequence[-1] == space[0] else sequence @@ -300,16 +377,16 @@ def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): def sequence_to_text(sequence): - '''Converts a sequence of IDs back to a string''' - result = '' + """Converts a sequence of IDs back to a string""" + result = "" for symbol_id in sequence: if symbol_id in _id_to_symbol: s = _id_to_symbol[symbol_id] # Enclose ARPAbet back in curly braces: - if len(s) > 1 and s[0] == '@': - s = '{%s}' % s[1:] + if len(s) > 1 and s[0] == "@": + s = "{%s}" % s[1:] result += s - return result.replace('}{', ' ') + return result.replace("}{", " ") def _clean_text(text, cleaner_names): @@ -323,17 +400,18 @@ def _symbols_to_sequence(symbols): def _arpabet_to_sequence(text): - return _symbols_to_sequence(['@' + s for s in text.split()]) + return _symbols_to_sequence(["@" + s for s in text.split()]) def _should_keep_symbol(s): - return s in _symbol_to_id and s != '_' and s != '~' + return s in _symbol_to_id and s != "_" and s != "~" VOCAB_FILES_NAMES = { "dict_file": "dict_file.txt", } + class GradTTSTokenizer(PreTrainedTokenizer): vocab_files_names = VOCAB_FILES_NAMES @@ -341,17 +419,17 @@ class GradTTSTokenizer(PreTrainedTokenizer): super().__init__(**kwargs) self.cmu = CMUDict(dict_file) self.dict_file = dict_file - + def __call__(self, text): x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None] x_lengths = torch.LongTensor([x.shape[-1]]) return x, x_lengths - - def save_vocabulary(self, save_directory: str, filename_prefix = None): + + def save_vocabulary(self, save_directory: str, filename_prefix=None): dict_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"] ) copyfile(self.dict_file, dict_file) - - return (dict_file, ) + + return (dict_file,) diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index a18aa44714..246661be38 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -4,13 +4,13 @@ import math import torch from torch import nn -import tqdm +import tqdm +from diffusers import DiffusionPipeline from diffusers.configuration_utils import ConfigMixin from diffusers.modeling_utils import ModelMixin -from diffusers import DiffusionPipeline -from .grad_tts_utils import GradTTSTokenizer # flake8: noqa +from .grad_tts_utils import GradTTSTokenizer # flake8: noqa def sequence_mask(length, max_length=None): @@ -382,7 +382,7 @@ class TextEncoder(ModelMixin, ConfigMixin): self.window_size = window_size self.spk_emb_dim = spk_emb_dim self.n_spks = n_spks - + self.emb = torch.nn.Embedding(n_vocab, n_channels) torch.nn.init.normal_(self.emb.weight, 0.0, n_channels**-0.5) @@ -403,7 +403,7 @@ class TextEncoder(ModelMixin, ConfigMixin): n_channels + (spk_emb_dim if n_spks > 1 else 0), filter_channels_dp, kernel_size, p_dropout ) - def forward(self, x, x_lengths, spk=None): + def forward(self, x, x_lengths, spk=None): x = self.emb(x) * math.sqrt(self.n_channels) x = torch.transpose(x, 1, -1) x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) @@ -424,26 +424,30 @@ class GradTTS(DiffusionPipeline): def __init__(self, unet, text_encoder, noise_scheduler, tokenizer): super().__init__() noise_scheduler = noise_scheduler.set_format("pt") - self.register_modules(unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer) - + self.register_modules( + unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer + ) + @torch.no_grad() - def __call__(self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None): + def __call__( + self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None + ): if torch_device is None: torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - + self.unet.to(torch_device) self.text_encoder.to(torch_device) - + x, x_lengths = self.tokenizer(text) x = x.to(torch_device) x_lengths = x_lengths.to(torch_device) - + if speaker_id is not None: - speaker_id= torch.LongTensor([speaker_id]).to(torch_device) - + speaker_id = torch.LongTensor([speaker_id]).to(torch_device) + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` mu_x, logw, x_mask = self.text_encoder(x, x_lengths) - + w = torch.exp(logw) * x_mask w_ceil = torch.ceil(w) * length_scale y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() @@ -461,16 +465,16 @@ class GradTTS(DiffusionPipeline): # Sample latent representation from terminal distribution N(mu_y, I) z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature - + xt = z * y_mask h = 1.0 / num_inference_steps for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps): - t = (1.0 - (t + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) + t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) time = t.unsqueeze(-1).unsqueeze(-1) - + residual = self.unet(xt, y_mask, mu_y, t, speaker_id) - + xt = self.noise_scheduler.step(xt, residual, mu_y, h, time) xt = xt * y_mask - - return xt[:, :, :y_max_length] \ No newline at end of file + + return xt[:, :, :y_max_length] diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 9e1cd3edc8..47e5f6a1db 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -19,6 +19,6 @@ from .classifier_free_guidance import ClassifierFreeGuidanceScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler -from .scheduling_pndm import PNDMScheduler from .scheduling_grad_tts import GradTTSScheduler +from .scheduling_pndm import PNDMScheduler from .scheduling_utils import SchedulerMixin diff --git a/src/diffusers/schedulers/scheduling_grad_tts.py b/src/diffusers/schedulers/scheduling_grad_tts.py index 11a557a3e1..2c154c87d6 100644 --- a/src/diffusers/schedulers/scheduling_grad_tts.py +++ b/src/diffusers/schedulers/scheduling_grad_tts.py @@ -36,11 +36,11 @@ class GradTTSScheduler(SchedulerMixin, ConfigMixin): self.timesteps = int(timesteps) self.set_format(tensor_format=tensor_format) - + def sample_noise(self, timestep): noise = self.beta_start + (self.beta_end - self.beta_start) * timestep return noise - + def step(self, xt, residual, mu, h, timestep): noise_t = self.sample_noise(timestep) dxt = 0.5 * (mu - xt - residual) From c2e48b23f8acf74e2b1efc435b53fb144daa28f6 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 16 Jun 2022 18:27:47 +0200 Subject: [PATCH 13/13] remove unused import --- src/diffusers/schedulers/scheduling_grad_tts.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_grad_tts.py b/src/diffusers/schedulers/scheduling_grad_tts.py index 2c154c87d6..ca42921daa 100644 --- a/src/diffusers/schedulers/scheduling_grad_tts.py +++ b/src/diffusers/schedulers/scheduling_grad_tts.py @@ -11,9 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math - -import numpy as np from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin