1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

refactor tts sampler a bit

This commit is contained in:
Patrick von Platen
2022-06-22 23:15:57 +02:00
parent 4fbf8c815e
commit f941fc9917
4 changed files with 30 additions and 17 deletions

View File

@@ -694,6 +694,7 @@ class CLIPTextModel(CLIPPreTrainedModel):
# END OF THE CLIP MODEL COPY-PASTE
#####################
def _extract_into_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.

View File

@@ -475,13 +475,15 @@ class GradTTSPipeline(DiffusionPipeline):
xt = z * y_mask
h = 1.0 / num_inference_steps
# (Patrick: TODO)
for t in tqdm.tqdm(range(num_inference_steps), total=num_inference_steps):
t_new = num_inference_steps - t - 1
t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
time = t.unsqueeze(-1).unsqueeze(-1)
residual = self.unet(xt, t, mu_y, y_mask, speaker_id)
xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
scheduler_residual = residual - mu_y + xt
xt = self.noise_scheduler.step(scheduler_residual, xt, t_new, num_inference_steps)
xt = xt * y_mask
return xt[:, :, :y_max_length]

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
@@ -19,29 +21,34 @@ from .scheduling_utils import SchedulerMixin
class GradTTSScheduler(SchedulerMixin, ConfigMixin):
def __init__(
    self,
    timesteps=1000,
    beta_start=0.05,
    beta_end=20,
    tensor_format="np",
):
    """Set up the scheduler and record its hyper-parameters in the config.

    Args:
        timesteps: Value stored in the config under ``timesteps``.
        beta_start: Lower end of the linear beta interpolation
            ``beta_start + (beta_end - beta_start) * t``.
        beta_end: Upper end of the same interpolation.
        tensor_format: Passed straight through to ``set_format``.
    """
    super().__init__()
    self.register_to_config(
        timesteps=timesteps,
        beta_start=beta_start,
        beta_end=beta_end,
    )
    self.set_format(tensor_format=tensor_format)

    # No beta table yet; ``step`` asks ``set_betas`` to build it on demand
    # because the table length depends on the number of inference steps.
    self.betas = None
def sample_noise(self, timestep):
    """Return the linearly interpolated beta value at ``timestep``.

    The result is ``beta_start`` when ``timestep`` is 0 and ``beta_end``
    when it is 1; nothing random is drawn despite the method name.
    """
    beta_span = self.beta_end - self.beta_start
    return self.beta_start + beta_span * timestep
def get_timesteps(self, num_inference_steps):
    """Return the midpoints of ``num_inference_steps`` equal slices of [0, 1].

    Args:
        num_inference_steps: Number of slices / sampler updates.

    Returns:
        A 1-D float array whose ``i``-th entry is
        ``(i + 0.5) / num_inference_steps``.
    """
    # One vectorised expression instead of a Python-level loop; the values
    # are the same as computing (i + 0.5) / n element by element.
    return (np.arange(num_inference_steps) + 0.5) / num_inference_steps
def step(self, xt, residual, mu, h, timestep):
    """Apply one update of size ``h`` to ``xt`` and return the new value.

    Args:
        xt: Current sample.
        residual: Model output subtracted inside the drift term.
        mu: Target the sample is pulled towards.
        h: Step size.
        timestep: Time at which ``sample_noise`` is evaluated.

    Returns:
        ``xt - 0.5 * (mu - xt - residual) * sample_noise(timestep) * h``.
    """
    beta_t = self.sample_noise(timestep)
    drift = 0.5 * (mu - xt - residual)
    return xt - drift * beta_t * h
def set_betas(self, num_inference_steps):
    """Build and cache the beta table for ``num_inference_steps`` updates.

    Each entry is ``beta_start + (beta_end - beta_start) * t`` for the
    corresponding value ``t`` returned by ``get_timesteps``. The table is
    stored on ``self.betas``; nothing is returned.
    """
    # asarray keeps this correct whether get_timesteps hands back an ndarray
    # or a plain sequence, and lets the interpolation run as array arithmetic.
    timesteps = np.asarray(self.get_timesteps(num_inference_steps))
    self.betas = self.beta_start + (self.beta_end - self.beta_start) * timesteps
def __len__(self):
    """Return the number of timesteps recorded in the config.

    ``config.timesteps`` is an integer count by default, on which ``len()``
    raises ``TypeError``; in that case the count itself is returned. A sized
    value (e.g. a sequence of timesteps) still reports its length.
    """
    timesteps = self.config.timesteps
    try:
        return len(timesteps)
    except TypeError:
        # Plain integer count: it already is the length.
        return int(timesteps)
def step(self, residual, sample, t, num_inference_steps):
    """Advance ``sample`` by one sampler update and return the result.

    Args:
        residual: Update direction supplied by the caller.
        sample: Current sample.
        t: Index into the cached beta table.
        num_inference_steps: Total number of updates; also the table length.

    Returns:
        ``sample + residual * betas[t] / num_inference_steps / 2``.
    """
    # This is a VE scheduler from https://arxiv.org/pdf/2011.13456.pdf (see Algorithm 2 in Appendix)
    # Rebuild the table when it is missing OR was built for a different step
    # count; otherwise a second run with another num_inference_steps would
    # index a stale table (wrong scaling or an IndexError).
    if self.betas is None or len(self.betas) != num_inference_steps:
        self.set_betas(num_inference_steps)

    beta_t = self.betas[t]
    beta_t_deriv = beta_t / num_inference_steps
    return sample + residual * beta_t_deriv / 2

View File

@@ -31,6 +31,7 @@ from diffusers import (
GlideSuperResUNetModel,
GlideTextToImageUNetModel,
GradTTSPipeline,
GradTTSScheduler,
LatentDiffusionPipeline,
PNDMPipeline,
PNDMScheduler,
@@ -705,6 +706,8 @@ class PipelineTesterMixin(unittest.TestCase):
def test_grad_tts(self):
model_id = "fusing/grad-tts-libri-tts"
grad_tts = GradTTSPipeline.from_pretrained(model_id)
noise_scheduler = GradTTSScheduler()
grad_tts.noise_scheduler = noise_scheduler
text = "Hello world, I missed you so much."
generator = torch.manual_seed(0)