1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[tests] Add test slices for Hunyuan Video (#11954)

update
This commit is contained in:
Aryan
2025-07-21 07:52:16 +05:30
committed by GitHub
parent cde02b061b
commit 67a8ec8bf5
4 changed files with 45 additions and 20 deletions

View File

@@ -229,12 +229,19 @@ class HunyuanVideoImageToVideoPipelineFastTests(
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
# NOTE: The expected video has 4 lesser frames because they are dropped in the pipeline
self.assertEqual(generated_video.shape, (5, 3, 16, 16))
expected_video = torch.randn(5, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
# fmt: off
expected_slice = torch.tensor([0.444, 0.479, 0.4485, 0.5752, 0.3539, 0.1548, 0.2706, 0.3593, 0.5323, 0.6635, 0.6795, 0.5255, 0.5091, 0.345, 0.4276, 0.4128])
# fmt: on
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(
torch.allclose(generated_slice, expected_slice, atol=1e-3),
"The generated video does not match the expected slice.",
)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)

View File

@@ -192,11 +192,18 @@ class HunyuanSkyreelsImageToVideoPipelineFastTests(
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
expected_video = torch.randn(9, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
# fmt: off
expected_slice = torch.tensor([0.5832, 0.5498, 0.4839, 0.4744, 0.4515, 0.4832, 0.496, 0.563, 0.5918, 0.5979, 0.5101, 0.6168, 0.6613, 0.536, 0.55, 0.5775])
# fmt: on
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(
torch.allclose(generated_slice, expected_slice, atol=1e-3),
"The generated video does not match the expected slice.",
)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)

View File

@@ -26,10 +26,7 @@ from diffusers import (
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import (
FasterCacheTesterMixin,
@@ -206,11 +203,18 @@ class HunyuanVideoPipelineFastTests(
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
expected_video = torch.randn(9, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
# fmt: off
expected_slice = torch.tensor([0.3946, 0.4649, 0.3196, 0.4569, 0.3312, 0.3687, 0.3216, 0.3972, 0.4469, 0.3888, 0.3929, 0.3802, 0.3479, 0.3888, 0.3825, 0.3542])
# fmt: on
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(
torch.allclose(generated_slice, expected_slice, atol=1e-3),
"The generated video does not match the expected slice.",
)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)

View File

@@ -227,11 +227,18 @@ class HunyuanVideoFramepackPipelineFastTests(
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (13, 3, 32, 32))
expected_video = torch.randn(13, 3, 32, 32)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
# fmt: off
expected_slice = torch.tensor([0.363, 0.3384, 0.3426, 0.3512, 0.3372, 0.3276, 0.417, 0.4061, 0.5221, 0.467, 0.4813, 0.4556, 0.4107, 0.3945, 0.4049, 0.4551])
# fmt: on
generated_slice = generated_video.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(
torch.allclose(generated_slice, expected_slice, atol=1e-3),
"The generated video does not match the expected slice.",
)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)