diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4633576e84..dc69d8bf35 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -10,6 +10,6 @@ from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel from .models.unet_grad_tts import UNetGradTTSModel from .models.unet_ldm import UNetLDMModel from .pipeline_utils import DiffusionPipeline -from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion -from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin +from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion +from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 08501c4b60..6792177193 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -4,7 +4,7 @@ import torch try: - from einops import rearrange, repeat + from einops import rearrange except: print("Einops is not installed") pass @@ -144,9 +144,11 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): self.pe_scale = pe_scale if n_spks > 1: + self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim) self.spk_mlp = torch.nn.Sequential( torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats) ) + self.time_pos_emb = SinusoidalPosEmb(dim) self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim)) @@ -189,6 +191,10 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): self.final_conv = torch.nn.Conv2d(dim, 1, 1) def forward(self, x, mask, mu, t, spk=None): + if self.n_spks > 1: + # Get speaker embedding + spk = self.spk_emb(spk) + if not isinstance(spk, type(None)): s = self.spk_mlp(spk) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 3d42bcd4ac..91af443c9f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,6 +1,7 @@ from .pipeline_bddm import BDDM from .pipeline_ddim import DDIM from .pipeline_ddpm import DDPM +from .pipeline_grad_tts import GradTTS try: diff --git a/src/diffusers/pipelines/grad_tts_utils.py b/src/diffusers/pipelines/grad_tts_utils.py new file mode 100644 index 0000000000..a96fbc2d16 --- /dev/null +++ b/src/diffusers/pipelines/grad_tts_utils.py @@ -0,0 +1,435 @@ +# tokenizer + +import os +import re +from shutil import copyfile + +import torch + + +try: + from transformers import PreTrainedTokenizer +except: + print("transformers is not installed") + +try: + from unidecode import unidecode +except: + print("unidecode is not installed") + pass + +try: + import inflect +except: + print("inflect is not installed") + pass + + +valid_symbols = [ + "AA", + "AA0", + "AA1", + "AA2", + "AE", + "AE0", + "AE1", + "AE2", + "AH", + "AH0", + "AH1", + "AH2", + "AO", + "AO0", + "AO1", + "AO2", + "AW", + "AW0", + "AW1", + "AW2", + "AY", + "AY0", + "AY1", + "AY2", + "B", + "CH", + "D", + "DH", + "EH", + "EH0", + "EH1", + "EH2", + "ER", + "ER0", + "ER1", + "ER2", + "EY", + "EY0", + "EY1", + "EY2", + "F", + "G", + "HH", + "IH", + "IH0", + "IH1", + "IH2", + "IY", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "NG", + "OW", + "OW0", + "OW1", + "OW2", + "OY", + "OY0", + "OY1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH", + "UH0", + "UH1", + "UH2", + "UW", + "UW0", + "UW1", + "UW2", + "V", + "W", + "Y", + "Z", + "ZH", +] + +_valid_symbol_set = set(valid_symbols) + + +def intersperse(lst, item): + # Adds blank symbol + result = [item] * (len(lst) * 2 + 1) + result[1::2] = lst + return result + + +class CMUDict: + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding="latin-1") as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} + self._entries = entries + + def __len__(self): + return len(self._entries) + + def lookup(self, word): + return self._entries.get(word.upper()) + + +_alt_re = re.compile(r"\([0-9]+\)") + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(" ") + for part in parts: + if part not in _valid_symbol_set: + return None + return " ".join(parts) + + +_whitespace_re = re.compile(r"\s+") + +_abbreviations = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, " ", text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text + + +_inflect = inflect.engine() +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") +_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"[0-9]+") + + +def _remove_commas(m): + return m.group(1).replace(",", "") + + +def _expand_decimal_point(m): + return m.group(1).replace(".", " point ") + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split(".") + if len(parts) > 2: + return match + " dollars" + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = "dollar" if dollars == 1 else "dollars" + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = "dollar" if dollars == 1 else "dollars" + return "%s %s" % (dollars, dollar_unit) + elif cents: + cent_unit = "cent" if cents == 1 else "cents" + return "%s %s" % (cents, cent_unit) + else: + return "zero dollars" + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return "two thousand" + elif num > 2000 and num < 2010: + return "two thousand " + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + " hundred" + else: + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + else: + return _inflect.number_to_words(num, andword="") + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r"\1 pounds", text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text + + +""" from https://github.com/keithito/tacotron """ + + +_pad = "_" +_punctuation = "!'(),.:;? " +_special = "-" +_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + +# Prepend "@" to ARPAbet symbols to ensure uniqueness: +_arpabet = ["@" + s for s in valid_symbols] + +# Export all symbols: +symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet + + +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + +_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)") + + +def get_arpabet(word, dictionary): + word_arpabet = dictionary.lookup(word) + if word_arpabet is not None: + return "{" + word_arpabet[0] + "}" + else: + return word + + +def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." + + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + dictionary: arpabet class with arpabet dictionary + + Returns: + List of integers corresponding to the symbols in the text + """ + sequence = [] + space = _symbols_to_sequence(" ") + # Check for curly braces and treat their contents as ARPAbet: + while len(text): + m = _curly_re.match(text) + if not m: + clean_text = _clean_text(text, cleaner_names) + if dictionary is not None: + clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")] + for i in range(len(clean_text)): + t = clean_text[i] + if t.startswith("{"): + sequence += _arpabet_to_sequence(t[1:-1]) + else: + sequence += _symbols_to_sequence(t) + sequence += space + else: + sequence += _symbols_to_sequence(clean_text) + break + sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) + sequence += _arpabet_to_sequence(m.group(2)) + text = m.group(3) + + # remove trailing space + if dictionary is not None: + sequence = sequence[:-1] if sequence[-1] == space[0] else sequence + return sequence + + +def sequence_to_text(sequence): + """Converts a sequence of IDs back to a string""" + result = "" + for symbol_id in sequence: + if symbol_id in _id_to_symbol: + s = _id_to_symbol[symbol_id] + # Enclose ARPAbet back in curly braces: + if len(s) > 1 and s[0] == "@": + s = "{%s}" % s[1:] + result += s + return result.replace("}{", " ") + + +def _clean_text(text, cleaner_names): + for cleaner in cleaner_names: + text = cleaner(text) + return text + + +def _symbols_to_sequence(symbols): + return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] + + +def _arpabet_to_sequence(text): + return _symbols_to_sequence(["@" + s for s in text.split()]) + + +def _should_keep_symbol(s): + return s in _symbol_to_id and s != "_" and s != "~" + + +VOCAB_FILES_NAMES = { + "dict_file": "dict_file.txt", +} + + +class GradTTSTokenizer(PreTrainedTokenizer): + vocab_files_names = VOCAB_FILES_NAMES + + def __init__(self, dict_file, **kwargs): + super().__init__(**kwargs) + self.cmu = CMUDict(dict_file) + self.dict_file = dict_file + + def __call__(self, text): + x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None] + x_lengths = torch.LongTensor([x.shape[-1]]) + return x, x_lengths + + def save_vocabulary(self, save_directory: str, filename_prefix=None): + dict_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"] + ) + + copyfile(self.dict_file, dict_file) + + return (dict_file,) diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 048db3785f..246661be38 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -5,9 +5,13 @@ import math import torch from torch import nn +import tqdm +from diffusers import DiffusionPipeline from diffusers.configuration_utils import ConfigMixin from diffusers.modeling_utils import ModelMixin +from .grad_tts_utils import GradTTSTokenizer # flake8: noqa + def sequence_mask(length, max_length=None): if max_length is None: @@ -414,3 +418,63 @@ class TextEncoder(ModelMixin, ConfigMixin): logw = self.proj_w(x_dp, x_mask) return mu, logw, x_mask + + +class GradTTS(DiffusionPipeline): + def __init__(self, unet, text_encoder, noise_scheduler, tokenizer): + super().__init__() + noise_scheduler = noise_scheduler.set_format("pt") + self.register_modules( + unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer + ) + + @torch.no_grad() + def __call__( + self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None + ): + if torch_device is None: + torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + self.unet.to(torch_device) + self.text_encoder.to(torch_device) + + x, x_lengths = self.tokenizer(text) + x = x.to(torch_device) + x_lengths = x_lengths.to(torch_device) + + if speaker_id is not None: + speaker_id = torch.LongTensor([speaker_id]).to(torch_device) + + # Get encoder_outputs `mu_x` and log-scaled token durations `logw` + mu_x, logw, x_mask = self.text_encoder(x, x_lengths) + + w = torch.exp(logw) * x_mask + w_ceil = torch.ceil(w) * length_scale + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_max_length = int(y_lengths.max()) + y_max_length_ = fix_len_compatibility(y_max_length) + + # Using obtained durations `w` construct alignment map `attn` + y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + + # Align encoded text and get mu_y + mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) + mu_y = mu_y.transpose(1, 2) + + # Sample latent representation from terminal distribution N(mu_y, I) + z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature + + xt = z * y_mask + h = 1.0 / num_inference_steps + for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps): + t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) + time = t.unsqueeze(-1).unsqueeze(-1) + + residual = self.unet(xt, y_mask, mu_y, t, speaker_id) + + xt = self.noise_scheduler.step(xt, residual, mu_y, h, time) + xt = xt * y_mask + + return xt[:, :, :y_max_length] diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 5e9dcaf64e..47e5f6a1db 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -19,5 +19,6 @@ from .classifier_free_guidance import ClassifierFreeGuidanceScheduler from .scheduling_ddim import DDIMScheduler from .scheduling_ddpm import DDPMScheduler +from .scheduling_grad_tts import GradTTSScheduler from .scheduling_pndm import PNDMScheduler from .scheduling_utils import SchedulerMixin diff --git a/src/diffusers/schedulers/scheduling_grad_tts.py b/src/diffusers/schedulers/scheduling_grad_tts.py new file mode 100644 index 0000000000..ca42921daa --- /dev/null +++ b/src/diffusers/schedulers/scheduling_grad_tts.py @@ -0,0 +1,49 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..configuration_utils import ConfigMixin +from .scheduling_utils import SchedulerMixin + + +class GradTTSScheduler(SchedulerMixin, ConfigMixin): + def __init__( + self, + timesteps=1000, + beta_start=0.0001, + beta_end=0.02, + tensor_format="np", + ): + super().__init__() + self.register( + timesteps=timesteps, + beta_start=beta_start, + beta_end=beta_end, + ) + self.timesteps = int(timesteps) + + self.set_format(tensor_format=tensor_format) + + def sample_noise(self, timestep): + noise = self.beta_start + (self.beta_end - self.beta_start) * timestep + return noise + + def step(self, xt, residual, mu, h, timestep): + noise_t = self.sample_noise(timestep) + dxt = 0.5 * (mu - xt - residual) + dxt = dxt * noise_t * h + xt = xt - dxt + return xt + + def __len__(self): + return self.timesteps