From d87134ada459a843ab75c8e2f7bddf71902763e5 Mon Sep 17 00:00:00 2001 From: Aryan Date: Mon, 21 Jul 2025 07:52:44 +0530 Subject: [PATCH] [tests] Add test slices for Cosmos (#11955) * test * try fix --- tests/pipelines/cosmos/test_cosmos.py | 12 ++++++++---- tests/pipelines/cosmos/test_cosmos2_text2image.py | 12 ++++++++---- tests/pipelines/cosmos/test_cosmos2_video2world.py | 12 ++++++++---- tests/pipelines/cosmos/test_cosmos_video2world.py | 12 ++++++++---- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/tests/pipelines/cosmos/test_cosmos.py b/tests/pipelines/cosmos/test_cosmos.py index 0c1024a9a9..4d3202f785 100644 --- a/tests/pipelines/cosmos/test_cosmos.py +++ b/tests/pipelines/cosmos/test_cosmos.py @@ -153,11 +153,15 @@ class CosmosTextToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase) inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 32, 32)) - expected_video = torch.randn(9, 3, 32, 32) - max_diff = np.abs(generated_video - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.0, 0.9686, 0.8549, 0.8078, 0.0, 0.8431, 1.0, 0.4863, 0.7098, 0.1098, 0.8157, 0.4235, 0.6353, 0.2549, 0.5137, 0.5333]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) def test_callback_inputs(self): sig = inspect.signature(self.pipeline_class.__call__) diff --git a/tests/pipelines/cosmos/test_cosmos2_text2image.py b/tests/pipelines/cosmos/test_cosmos2_text2image.py index 386bf161a0..cc2fcec641 100644 --- a/tests/pipelines/cosmos/test_cosmos2_text2image.py +++ b/tests/pipelines/cosmos/test_cosmos2_text2image.py @@ -140,11 +140,15 @@ class Cosmos2TextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase inputs = self.get_dummy_inputs(device) image = pipe(**inputs).images generated_image = image[0] - self.assertEqual(generated_image.shape, (3, 32, 32)) - expected_video = torch.randn(3, 32, 32) - max_diff = np.abs(generated_image - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.4784, 0.4784, 0.4784, 0.4784, 0.4784, 0.4902, 0.4588, 0.5333]) + # fmt: on + + generated_slice = generated_image.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) def test_callback_inputs(self): sig = inspect.signature(self.pipeline_class.__call__) diff --git a/tests/pipelines/cosmos/test_cosmos2_video2world.py b/tests/pipelines/cosmos/test_cosmos2_video2world.py index 421e3a1ad3..b23c8aed17 100644 --- a/tests/pipelines/cosmos/test_cosmos2_video2world.py +++ b/tests/pipelines/cosmos/test_cosmos2_video2world.py @@ -147,11 +147,15 @@ class Cosmos2VideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCas inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 32, 32)) - expected_video = torch.randn(9, 3, 32, 32) - max_diff = np.abs(generated_video - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.5098, 0.5137, 0.5176, 0.5098, 0.5255, 0.5412, 0.5098, 0.5059]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) def test_components_function(self): init_components = self.get_dummy_components() diff --git a/tests/pipelines/cosmos/test_cosmos_video2world.py b/tests/pipelines/cosmos/test_cosmos_video2world.py index 2b893e9970..d0dba5575b 100644 --- a/tests/pipelines/cosmos/test_cosmos_video2world.py +++ b/tests/pipelines/cosmos/test_cosmos_video2world.py @@ -159,11 +159,15 @@ class CosmosVideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase inputs = self.get_dummy_inputs(device) video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (9, 3, 32, 32)) - expected_video = torch.randn(9, 3, 32, 32) - max_diff = np.abs(generated_video - expected_video).max() - self.assertLessEqual(max_diff, 1e10) + + # fmt: off + expected_slice = torch.tensor([0.0, 0.8275, 0.7529, 0.7294, 0.0, 0.6, 1.0, 0.3804, 0.6667, 0.0863, 0.8784, 0.5922, 0.6627, 0.2784, 0.5725, 0.7765]) + # fmt: on + + generated_slice = generated_video.flatten() + generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) + self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) def test_components_function(self): init_components = self.get_dummy_components()