From a7bc052e899936396dfcd08b0a5a88abe2088b5f Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 16 Dec 2025 10:44:02 +0100 Subject: [PATCH] Improve dummy inputs and add test for LTX 2 transformer consistency --- .../test_models_transformer_ltx2.py | 122 +++++++++++++++++- 1 file changed, 116 insertions(+), 6 deletions(-) diff --git a/tests/models/transformers/test_models_transformer_ltx2.py b/tests/models/transformers/test_models_transformer_ltx2.py index c382a63eaa..0bf08f161d 100644 --- a/tests/models/transformers/test_models_transformer_ltx2.py +++ b/tests/models/transformers/test_models_transformer_ltx2.py @@ -17,7 +17,7 @@ import unittest import torch -from diffusers import LTX2VideoTransformer3DModel +from diffusers import LTX2VideoTransformer3DModel, attention_backend from ...testing_utils import enable_full_determinism, torch_device from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin @@ -35,16 +35,15 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): def dummy_input(self): # Common batch_size = 2 - # NOTE: at 25 FPS, using the same num_frames for hidden_states and audio_hidden_states will result in video - # and audio of equal duration - num_frames = 2 # Video + num_frames = 2 num_channels = 4 height = 16 width = 16 # Audio + audio_num_frames = 9 audio_num_channels = 2 num_mel_bins = 2 @@ -54,12 +53,12 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device) audio_hidden_states = torch.randn( - (batch_size, num_frames, audio_num_channels * num_mel_bins) + (batch_size, audio_num_frames, audio_num_channels * num_mel_bins) ).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device) - timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + timestep = torch.rand((batch_size,)).to(torch_device) return { "hidden_states": hidden_states, @@ -71,6 +70,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): "num_frames": num_frames, "height": height, "width": width, + "audio_num_frames": audio_num_frames, "fps": 25.0, } @@ -107,6 +107,116 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase): expected_set = {"LTX2VideoTransformer3DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) + def test_ltx2_consistency(self, seed=0, dtype=torch.float32): + torch.manual_seed(seed) + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + # Calculate dummy inputs in a custom manner to ensure compatibility with original code + batch_size = 2 + num_frames = 9 + latent_frames = 2 + text_embedding_dim = 16 + text_seq_len = 16 + fps = 25.0 + sampling_rate = 16000.0 + hop_length = 160.0 + + sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") + timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device) + + num_channels = 4 + latent_height = 4 + latent_width = 4 + hidden_states = torch.randn( + (batch_size, num_channels, latent_frames, latent_height, latent_width), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + # Patchify video latents (with patch_size (1, 1, 1)) + hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1) + hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3) + encoder_hidden_states = torch.randn( + (batch_size, text_seq_len, text_embedding_dim), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + + audio_num_channels = 2 + num_mel_bins = 2 + latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps)) + audio_hidden_states = torch.randn( + (batch_size, audio_num_channels, latent_length, num_mel_bins), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + # Patchify audio latents + audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3) + audio_encoder_hidden_states = torch.randn( + (batch_size, text_seq_len, text_embedding_dim), + generator=torch.manual_seed(seed), + dtype=dtype, + device="cpu", + ) + + inputs_dict = { + "hidden_states": hidden_states.to(device=torch_device), + "audio_hidden_states": audio_hidden_states.to(device=torch_device), + "encoder_hidden_states": encoder_hidden_states.to(device=torch_device), + "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device), + "timestep": timestep, + "num_frames": latent_frames, + "height": latent_height, + "width": latent_width, + "audio_num_frames": num_frames, + "fps": 25.0, + } + + model = self.model_class.from_pretrained( + "diffusers-internal-dev/dummy-ltx2", + subfolder="transformer", + device_map="cpu", + ) + # torch.manual_seed(seed) + # model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with attention_backend("native"): + with torch.no_grad(): + output = model(**inputs_dict) + + video_output, audio_output = output.to_tuple() + + self.assertIsNotNone(video_output) + self.assertIsNotNone(audio_output) + + # input & output have to have the same shape + video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels) + self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match") + audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins) + self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match") + + # Check against expected slice + # fmt: off + video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676]) + audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692]) + # fmt: on + + video_output_flat = video_output.cpu().flatten().float() + video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]]) + print(f"Video Expected Slice: {video_expected_slice}") + print(f"Video Generated Slice: {video_generated_slice}") + self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4)) + + audio_output_flat = audio_output.cpu().flatten().float() + audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]]) + print(f"Audio Expected Slice: {audio_expected_slice}") + print(f"Audio Generated Slice: {audio_generated_slice}") + self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4)) + class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): model_class = LTX2VideoTransformer3DModel