mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
add UNetGradTTSModelTests
This commit is contained in:
@@ -190,7 +190,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
||||
self.final_block = Block(dim, dim)
|
||||
self.final_conv = torch.nn.Conv2d(dim, 1, 1)
|
||||
|
||||
def forward(self, x, mask, mu, t, spk=None):
|
||||
def forward(self, x, timesteps, mu, mask, spk=None):
|
||||
if self.n_spks > 1:
|
||||
# Get speaker embedding
|
||||
spk = self.spk_emb(spk)
|
||||
@@ -198,7 +198,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
|
||||
if not isinstance(spk, type(None)):
|
||||
s = self.spk_mlp(spk)
|
||||
|
||||
t = self.time_pos_emb(t, scale=self.pe_scale)
|
||||
t = self.time_pos_emb(timesteps, scale=self.pe_scale)
|
||||
t = self.mlp(t)
|
||||
|
||||
if self.n_spks < 2:
|
||||
|
||||
@@ -472,7 +472,7 @@ class GradTTS(DiffusionPipeline):
|
||||
t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
|
||||
time = t.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
residual = self.unet(xt, y_mask, mu_y, t, speaker_id)
|
||||
residual = self.unet(xt, t, mu_y, y_mask, speaker_id)
|
||||
|
||||
xt = self.noise_scheduler.step(xt, residual, mu_y, h, time)
|
||||
xt = xt * y_mask
|
||||
|
||||
@@ -35,6 +35,7 @@ from diffusers import (
|
||||
PNDMScheduler,
|
||||
UNetModel,
|
||||
UNetLDMModel,
|
||||
UNetGradTTSModel,
|
||||
)
|
||||
from diffusers.configuration_utils import ConfigMixin
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
@@ -410,6 +411,78 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNetGradTTSModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 32
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
condition = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
mask = floats_tensor((batch_size, 1, seq_len)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask}
|
||||
|
||||
@property
|
||||
def get_input_shape(self):
|
||||
return (4, 32, 16)
|
||||
|
||||
@property
|
||||
def get_output_shape(self):
|
||||
return (4, 32, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"dim": 64,
|
||||
"groups": 4,
|
||||
"dim_mults": (1, 2),
|
||||
"n_feats": 32,
|
||||
"pe_scale": 1000,
|
||||
"n_spks": 1,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
def test_output_pretrained(self):
|
||||
model = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy")
|
||||
model.eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_features = model.config.n_feats
|
||||
seq_len = 16
|
||||
noise = torch.randn((1, num_features, seq_len))
|
||||
condition = torch.randn((1, num_features, seq_len))
|
||||
mask = torch.randn((1, 1, seq_len))
|
||||
time_step = torch.tensor([10])
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(noise, time_step, condition, mask)
|
||||
|
||||
output_slice = output[0, -3:, -3:].flatten()
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617])
|
||||
# fmt: on
|
||||
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
|
||||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
# 1. Load models
|
||||
|
||||
Reference in New Issue
Block a user