diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 84bd64d9d0..a2bdd951e4 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -190,7 +190,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): self.final_block = Block(dim, dim) self.final_conv = torch.nn.Conv2d(dim, 1, 1) - def forward(self, x, mask, mu, t, spk=None): + def forward(self, x, timesteps, mu, mask, spk=None): if self.n_spks > 1: # Get speaker embedding spk = self.spk_emb(spk) @@ -198,7 +198,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin): if not isinstance(spk, type(None)): s = self.spk_mlp(spk) - t = self.time_pos_emb(t, scale=self.pe_scale) + t = self.time_pos_emb(timesteps, scale=self.pe_scale) t = self.mlp(t) if self.n_spks < 2: diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index c9fad2192c..4201124923 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -472,7 +472,7 @@ class GradTTS(DiffusionPipeline): t = (1.0 - (t + 0.5) * h) * torch.ones(z.shape[0], dtype=z.dtype, device=z.device) time = t.unsqueeze(-1).unsqueeze(-1) - residual = self.unet(xt, y_mask, mu_y, t, speaker_id) + residual = self.unet(xt, t, mu_y, y_mask, speaker_id) xt = self.noise_scheduler.step(xt, residual, mu_y, h, time) xt = xt * y_mask diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index d6e90ea55a..7b2530a75e 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -35,6 +35,7 @@ from diffusers import ( PNDMScheduler, UNetModel, UNetLDMModel, + UNetGradTTSModel, ) from diffusers.configuration_utils import ConfigMixin from diffusers.pipeline_utils import DiffusionPipeline @@ -410,6 +411,78 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) +class UNetGradTTSModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNetGradTTSModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 32 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + condition = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + mask = floats_tensor((batch_size, 1, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"x": noise, "timesteps": time_step, "mu": condition, "mask": mask} + + @property + def get_input_shape(self): + return (4, 32, 16) + + @property + def get_output_shape(self): + return (4, 32, 16) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "dim": 64, + "groups": 4, + "dim_mults": (1, 2), + "n_feats": 32, + "pe_scale": 1000, + "n_spks": 1, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_from_pretrained_hub(self): + model, loading_info = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy", output_loading_info=True) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + def test_output_pretrained(self): + model = UNetGradTTSModel.from_pretrained("fusing/unet-grad-tts-dummy") + model.eval() + + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = model.config.n_feats + seq_len = 16 + noise = torch.randn((1, num_features, seq_len)) + condition = torch.randn((1, num_features, seq_len)) + mask = torch.randn((1, 1, seq_len)) + time_step = torch.tensor([10]) + + with torch.no_grad(): + output = model(noise, time_step, condition, mask) + + output_slice = output[0, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-0.0690, -0.0531, 0.0633, -0.0660, -0.0541, 0.0650, -0.0656, -0.0555, 0.0617]) + # fmt: on + + self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3)) + + class PipelineTesterMixin(unittest.TestCase): def test_from_pretrained_save_pretrained(self): # 1. Load models