Add GradTTS
src/diffusers/__init__.py

@@ -10,6 +10,6 @@ from .models.unet_glide import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
+from .models.unet_grad_tts import UNetGradTTSModel
 from .models.unet_ldm import UNetLDMModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, LatentDiffusion
-from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin
+from .pipelines import BDDM, DDIM, DDPM, GLIDE, PNDM, GradTTS, LatentDiffusion
+from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
 from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
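With these exports in place the new classes become importable from the package root; a quick sanity check:

    # New top-level exports added by this commit.
    from diffusers import GradTTS, GradTTSScheduler, UNetGradTTSModel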
src/diffusers/models/unet_grad_tts.py

@@ -4,7 +4,7 @@ import torch


 try:
-    from einops import rearrange, repeat
+    from einops import rearrange
 except:
     print("Einops is not installed")
     pass
@@ -144,9 +144,11 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.pe_scale = pe_scale

         if n_spks > 1:
+            self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
             self.spk_mlp = torch.nn.Sequential(
                 torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
             )
+
         self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))
@@ -189,6 +191,10 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
         self.final_conv = torch.nn.Conv2d(dim, 1, 1)

     def forward(self, x, mask, mu, t, spk=None):
+        if self.n_spks > 1:
+            # Get speaker embedding
+            spk = self.spk_emb(spk)
+
         if not isinstance(spk, type(None)):
             s = self.spk_mlp(spk)
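For reference, a sketch of the tensors the changed forward now expects; the shapes are assumptions inferred from the pipeline code further down, not part of this commit:

    import torch

    b, n_feats, frames = 1, 80, 172
    xt = torch.randn(b, n_feats, frames)    # noisy mel-spectrogram
    y_mask = torch.ones(b, 1, frames)       # frame validity mask
    mu_y = torch.randn(b, n_feats, frames)  # aligned encoder output
    t = torch.rand(b)                       # diffusion time in [0, 1]
    spk = torch.LongTensor([15])            # raw speaker id; the model now embeds it itself

    # residual = unet(xt, y_mask, mu_y, t, spk)  # predicted score, same shape as xt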
src/diffusers/pipelines/__init__.py

@@ -1,6 +1,7 @@
 from .pipeline_bddm import BDDM
 from .pipeline_ddim import DDIM
 from .pipeline_ddpm import DDPM
+from .pipeline_grad_tts import GradTTS


 try:
src/diffusers/pipelines/grad_tts_utils.py (new file)

@@ -0,0 +1,435 @@
# tokenizer

import os
import re
from shutil import copyfile

import torch


try:
    from transformers import PreTrainedTokenizer
except ImportError:
    print("transformers is not installed")

try:
    from unidecode import unidecode
except ImportError:
    print("unidecode is not installed")
    pass

try:
    import inflect
except ImportError:
    print("inflect is not installed")
    pass

# ARPAbet phoneme set (stress-marked vowels plus consonants)
valid_symbols = [
    "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
    "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
    "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2",
    "EY", "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2",
    "IY", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG",
    "OW", "OW0", "OW1", "OW2", "OY", "OY0", "OY1", "OY2", "P", "R", "S", "SH",
    "T", "TH", "UH", "UH0", "UH1", "UH2", "UW", "UW0", "UW1", "UW2",
    "V", "W", "Y", "Z", "ZH",
]

_valid_symbol_set = set(valid_symbols)


def intersperse(lst, item):
    # Adds blank symbol
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

class CMUDict:
    # Thin wrapper around the CMU pronouncing dictionary
    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        return self._entries.get(word.upper())

_alt_re = re.compile(r"\([0-9]+\)")


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            parts = line.split("  ")  # CMUdict separates word and pronunciation with two spaces
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:
                    cmudict[word].append(pronunciation)
                else:
                    cmudict[word] = [pronunciation]
    return cmudict

def _get_pronunciation(s):
    parts = s.strip().split(" ")
    for part in parts:
        if part not in _valid_symbol_set:
            return None
    return " ".join(parts)


_whitespace_re = re.compile(r"\s+")
_abbreviations = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]


def expand_abbreviations(text):
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)
    return text

def expand_numbers(text):
    return normalize_numbers(text)


def lowercase(text):
    return text.lower()


def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text)


def convert_to_ascii(text):
    return unidecode(text)


def basic_cleaners(text):
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def english_cleaners(text):
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_numbers(text)
    text = expand_abbreviations(text)
    text = collapse_whitespace(text)
    return text

_inflect = inflect.engine()
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
_dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"[0-9]+")

def _remove_commas(m):
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace(".", " point ")


def _expand_dollars(m):
    match = m.group(1)
    parts = match.split(".")
    if len(parts) > 2:
        return match + " dollars"
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if dollars and cents:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
    elif dollars:
        dollar_unit = "dollar" if dollars == 1 else "dollars"
        return "%s %s" % (dollars, dollar_unit)
    elif cents:
        cent_unit = "cent" if cents == 1 else "cents"
        return "%s %s" % (cents, cent_unit)
    else:
        return "zero dollars"

def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return "two thousand"
        elif num > 2000 and num < 2010:
            return "two thousand " + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            return _inflect.number_to_words(num // 100) + " hundred"
        else:
            return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    else:
        return _inflect.number_to_words(num, andword="")

def normalize_numbers(text):
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r"\1 pounds", text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text


# from https://github.com/keithito/tacotron

_pad = "_"
_punctuation = "!'(),.:;? "
_special = "-"
_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

# Prepend "@" to ARPAbet symbols to ensure uniqueness:
_arpabet = ["@" + s for s in valid_symbols]

# Export all symbols:
symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet


_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

_curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")

def get_arpabet(word, dictionary):
    word_arpabet = dictionary.lookup(word)
    if word_arpabet is not None:
        return "{" + word_arpabet[0] + "}"
    else:
        return word

def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None):
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: cleaner functions to run the text through
        dictionary: CMUDict instance used for ARPAbet lookups

    Returns:
        List of integers corresponding to the symbols in the text
    """
    sequence = []
    space = _symbols_to_sequence(" ")
    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            clean_text = _clean_text(text, cleaner_names)
            if dictionary is not None:
                clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")]
                for i in range(len(clean_text)):
                    t = clean_text[i]
                    if t.startswith("{"):
                        sequence += _arpabet_to_sequence(t[1:-1])
                    else:
                        sequence += _symbols_to_sequence(t)
                    sequence += space
            else:
                sequence += _symbols_to_sequence(clean_text)
            break
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # remove trailing space
    if dictionary is not None:
        sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
    return sequence

def sequence_to_text(sequence):
    """Converts a sequence of IDs back to a string"""
    result = ""
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == "@":
                s = "{%s}" % s[1:]
            result += s
    return result.replace("}{", " ")

def _clean_text(text, cleaner_names):
    for cleaner in cleaner_names:
        text = cleaner(text)
    return text


def _symbols_to_sequence(symbols):
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]


def _arpabet_to_sequence(text):
    return _symbols_to_sequence(["@" + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s != "_" and s != "~"

VOCAB_FILES_NAMES = {
    "dict_file": "dict_file.txt",
}


class GradTTSTokenizer(PreTrainedTokenizer):
    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(self, dict_file, **kwargs):
        super().__init__(**kwargs)
        self.cmu = CMUDict(dict_file)
        self.dict_file = dict_file

    def __call__(self, text):
        x = torch.LongTensor(intersperse(text_to_sequence(text, dictionary=self.cmu), len(symbols)))[None]
        x_lengths = torch.LongTensor([x.shape[-1]])
        return x, x_lengths

    def save_vocabulary(self, save_directory: str, filename_prefix=None):
        dict_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["dict_file"]
        )

        copyfile(self.dict_file, dict_file)

        return (dict_file,)
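A minimal usage sketch for the tokenizer; the dictionary path is illustrative, any CMU pronouncing dictionary file works:

    from diffusers.pipelines.grad_tts_utils import GradTTSTokenizer

    tokenizer = GradTTSTokenizer("cmu_dictionary.txt")  # path is an assumption
    x, x_lengths = tokenizer("Hello world")
    # x: LongTensor of shape (1, 2 * n + 1): n symbol IDs with the blank ID len(symbols) interspersed
    # x_lengths: LongTensor([x.shape[-1]])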
src/diffusers/pipelines/pipeline_grad_tts.py

@@ -5,9 +5,13 @@ import math
 import torch
 from torch import nn

+import tqdm
+from diffusers import DiffusionPipeline
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.modeling_utils import ModelMixin

+from .grad_tts_utils import GradTTSTokenizer  # flake8: noqa
+

 def sequence_mask(length, max_length=None):
     if max_length is None:
@@ -414,3 +418,63 @@ class TextEncoder(ModelMixin, ConfigMixin):
         logw = self.proj_w(x_dp, x_mask)

         return mu, logw, x_mask
+
+
+class GradTTS(DiffusionPipeline):
+    def __init__(self, unet, text_encoder, noise_scheduler, tokenizer):
+        super().__init__()
+        noise_scheduler = noise_scheduler.set_format("pt")
+        self.register_modules(
+            unet=unet, text_encoder=text_encoder, noise_scheduler=noise_scheduler, tokenizer=tokenizer
+        )
+
+    @torch.no_grad()
+    def __call__(
+        self, text, num_inference_steps=50, temperature=1.3, length_scale=0.91, speaker_id=15, torch_device=None
+    ):
+        if torch_device is None:
+            torch_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+        self.unet.to(torch_device)
+        self.text_encoder.to(torch_device)
+
+        x, x_lengths = self.tokenizer(text)
+        x = x.to(torch_device)
+        x_lengths = x_lengths.to(torch_device)
+
+        if speaker_id is not None:
+            speaker_id = torch.LongTensor([speaker_id]).to(torch_device)
+
+        # Get encoder outputs `mu_x` and log-scaled token durations `logw`
+        mu_x, logw, x_mask = self.text_encoder(x, x_lengths)
+
+        w = torch.exp(logw) * x_mask
+        w_ceil = torch.ceil(w) * length_scale
+        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
+        y_max_length = int(y_lengths.max())
+        y_max_length_ = fix_len_compatibility(y_max_length)
+
+        # Using the obtained durations `w`, construct the alignment map `attn`
+        y_mask = sequence_mask(y_lengths, y_max_length_).unsqueeze(1).to(x_mask.dtype)
+        attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
+        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
+
+        # Align the encoded text to get mu_y
+        mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
+        mu_y = mu_y.transpose(1, 2)
+
+        # Sample the latent representation from the terminal distribution N(mu_y, I)
+        z = mu_y + torch.randn_like(mu_y, device=mu_y.device) / temperature
+
+        xt = z * y_mask
+        h = 1.0 / num_inference_steps
+        for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
+            t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
+            time = t.unsqueeze(-1).unsqueeze(-1)
+
+            residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
+
+            xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
+            xt = xt * y_mask
+
+        return xt[:, :, :y_max_length]
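End to end, the pipeline can then be exercised roughly as follows; the checkpoint path is a placeholder, not a published model id:

    from diffusers import GradTTS

    pipe = GradTTS.from_pretrained("path/to/grad-tts-checkpoint")  # placeholder path
    mel = pipe("Hello world from Grad-TTS!", num_inference_steps=50)
    # `mel` has shape (batch, n_feats, frames); a separate vocoder (e.g. HiFi-GAN)
    # is still needed to turn the mel-spectrogram into a waveform.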
src/diffusers/schedulers/__init__.py

@@ -19,5 +19,6 @@
 from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
+from .scheduling_grad_tts import GradTTSScheduler
 from .scheduling_pndm import PNDMScheduler
 from .scheduling_utils import SchedulerMixin
src/diffusers/schedulers/scheduling_grad_tts.py (new file)

@@ -0,0 +1,49 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin


class GradTTSScheduler(SchedulerMixin, ConfigMixin):
    def __init__(
        self,
        timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        tensor_format="np",
    ):
        super().__init__()
        self.register(
            timesteps=timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
        )
        self.timesteps = int(timesteps)

        self.set_format(tensor_format=tensor_format)

    def sample_noise(self, timestep):
        # Linear noise schedule: beta(t) = beta_start + (beta_end - beta_start) * t
        noise = self.beta_start + (self.beta_end - self.beta_start) * timestep
        return noise

    def step(self, xt, residual, mu, h, timestep):
        # One explicit Euler step of the Grad-TTS reverse ODE, integrated backwards in time
        noise_t = self.sample_noise(timestep)
        dxt = 0.5 * (mu - xt - residual)
        dxt = dxt * noise_t * h
        xt = xt - dxt
        return xt

    def __len__(self):
        return self.timesteps
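The step method is one explicit Euler step of the Grad-TTS reverse ODE, dx_t = 0.5 * beta_t * (mu - x_t - s_theta) dt, integrated backwards from t = 1 to t = 0 with the linear schedule beta_t = beta_start + (beta_end - beta_start) * t. A tiny tensor-only sketch of a single step, with a zero tensor standing in for the UNet output:

    import torch
    from diffusers import GradTTSScheduler

    sched = GradTTSScheduler(tensor_format="pt")
    xt = torch.zeros(1, 80, 10)        # current noisy mel
    mu = torch.ones(1, 80, 10)         # encoder output (terminal mean)
    residual = torch.zeros(1, 80, 10)  # stand-in for the UNet's score estimate
    h = 1.0 / 50                       # step size for 50 inference steps
    t = torch.tensor(0.99)             # time near the terminal distribution

    xt = sched.step(xt, residual, mu, h, t)  # xt - 0.5 * (mu - xt - residual) * beta(t) * h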