
Improve dummy inputs and add test for LTX 2 transformer consistency

Daniel Gu
2025-12-16 10:44:02 +01:00
parent 57a8b9c330
commit a7bc052e89


@@ -17,7 +17,7 @@ import unittest
 import torch
-from diffusers import LTX2VideoTransformer3DModel
+from diffusers import LTX2VideoTransformer3DModel, attention_backend
 from ...testing_utils import enable_full_determinism, torch_device
 from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
@@ -35,16 +35,15 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
     def dummy_input(self):
         # Common
         batch_size = 2
+        # NOTE: at 25 FPS, using the same num_frames for hidden_states and audio_hidden_states will result in video
+        # and audio of equal duration
+        num_frames = 2
         # Video
-        num_frames = 2
         num_channels = 4
         height = 16
         width = 16
         # Audio
+        audio_num_frames = 9
         audio_num_channels = 2
         num_mel_bins = 2
@@ -54,12 +53,12 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
         hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
         audio_hidden_states = torch.randn(
-            (batch_size, num_frames, audio_num_channels * num_mel_bins)
+            (batch_size, audio_num_frames, audio_num_channels * num_mel_bins)
         ).to(torch_device)
         encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
         audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
         encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+        timestep = torch.rand((batch_size,)).to(torch_device)
         return {
             "hidden_states": hidden_states,
@@ -71,6 +70,7 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
             "num_frames": num_frames,
             "height": height,
             "width": width,
+            "audio_num_frames": audio_num_frames,
             "fps": 25.0,
         }
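
As an aside on the NOTE above, here is a minimal arithmetic sketch (illustrative only, not part of the commit) of why these dummy inputs line up in duration. The sampling rate, hop length, and latent-length formula are taken from the consistency test added in the hunk below, and the pairing of 2 video latent frames with 9 pixel frames is likewise the pairing that test uses rather than something derived here:

    # Illustrative only: at 25 FPS the audio latent rate equals the video frame rate,
    # so 9 pixel frames of video and 9 audio latent frames both span 0.36 s.
    fps = 25.0
    sampling_rate = 16000.0
    hop_length = 160.0
    num_pixel_frames = 9                                      # paired with 2 video latent frames in the test below
    audio_latent_rate = sampling_rate / hop_length / 4        # 25.0 audio latent frames per second
    duration = num_pixel_frames / fps                         # 0.36 s of video at 25 FPS
    audio_latent_length = int(audio_latent_rate * duration)   # 9, matching audio_num_frames = 9 above
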
@@ -107,6 +107,116 @@ class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
         expected_set = {"LTX2VideoTransformer3DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+    def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
+        torch.manual_seed(seed)
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        # Calculate dummy inputs in a custom manner to ensure compatibility with original code
+        batch_size = 2
+        num_frames = 9
+        latent_frames = 2
+        text_embedding_dim = 16
+        text_seq_len = 16
+        fps = 25.0
+        sampling_rate = 16000.0
+        hop_length = 160.0
+        sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu")
+        timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
+        num_channels = 4
+        latent_height = 4
+        latent_width = 4
+        hidden_states = torch.randn(
+            (batch_size, num_channels, latent_frames, latent_height, latent_width),
+            generator=torch.manual_seed(seed),
+            dtype=dtype,
+            device="cpu",
+        )
+        # Patchify video latents (with patch_size (1, 1, 1))
+        hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
+        hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
+        encoder_hidden_states = torch.randn(
+            (batch_size, text_seq_len, text_embedding_dim),
+            generator=torch.manual_seed(seed),
+            dtype=dtype,
+            device="cpu",
+        )
+        audio_num_channels = 2
+        num_mel_bins = 2
+        latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
+        audio_hidden_states = torch.randn(
+            (batch_size, audio_num_channels, latent_length, num_mel_bins),
+            generator=torch.manual_seed(seed),
+            dtype=dtype,
+            device="cpu",
+        )
+        # Patchify audio latents
+        audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
+        audio_encoder_hidden_states = torch.randn(
+            (batch_size, text_seq_len, text_embedding_dim),
+            generator=torch.manual_seed(seed),
+            dtype=dtype,
+            device="cpu",
+        )
+        inputs_dict = {
+            "hidden_states": hidden_states.to(device=torch_device),
+            "audio_hidden_states": audio_hidden_states.to(device=torch_device),
+            "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
+            "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
+            "timestep": timestep,
+            "num_frames": latent_frames,
+            "height": latent_height,
+            "width": latent_width,
+            "audio_num_frames": num_frames,
+            "fps": 25.0,
+        }
+        model = self.model_class.from_pretrained(
+            "diffusers-internal-dev/dummy-ltx2",
+            subfolder="transformer",
+            device_map="cpu",
+        )
+        # torch.manual_seed(seed)
+        # model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+        with attention_backend("native"):
+            with torch.no_grad():
+                output = model(**inputs_dict)
+        video_output, audio_output = output.to_tuple()
+        self.assertIsNotNone(video_output)
+        self.assertIsNotNone(audio_output)
+        # input & output have to have the same shape
+        video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
+        self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
+        audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
+        self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")
+        # Check against expected slice
+        # fmt: off
+        video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
+        audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
+        # fmt: on
+        video_output_flat = video_output.cpu().flatten().float()
+        video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
+        print(f"Video Expected Slice: {video_expected_slice}")
+        print(f"Video Generated Slice: {video_generated_slice}")
+        self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
+        audio_output_flat = audio_output.cpu().flatten().float()
+        audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
+        print(f"Audio Expected Slice: {audio_expected_slice}")
+        print(f"Audio Generated Slice: {audio_generated_slice}")
+        self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
 class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
     model_class = LTX2VideoTransformer3DModel
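
For readers tracing the shape assertions in test_ltx2_consistency, a short standalone sketch (assuming only torch, illustrative and not part of the commit) that repeats the video patchify reshape from the test with patch_size (1, 1, 1) and shows why the transformer output is compared against (batch_size, latent_frames * latent_height * latent_width, num_channels):

    import torch

    # Same dummy latent sizes as the consistency test above.
    batch_size, num_channels, latent_frames, latent_height, latent_width = 2, 4, 2, 4, 4
    latents = torch.randn(batch_size, num_channels, latent_frames, latent_height, latent_width)

    # Two-step patchify used in the test: split every (1, 1, 1) patch out, then flatten
    # (frame, height, width) into a single token axis with num_channels values per token.
    tokens = latents.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
    tokens = tokens.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)

    # This matches the video_expected_shape asserted against the transformer output.
    assert tokens.shape == (batch_size, latent_frames * latent_height * latent_width, num_channels)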