mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -153,11 +153,15 @@ class CosmosTextToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
|
||||
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
|
||||
expected_video = torch.randn(9, 3, 32, 32)
|
||||
max_diff = np.abs(generated_video - expected_video).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.0, 0.9686, 0.8549, 0.8078, 0.0, 0.8431, 1.0, 0.4863, 0.7098, 0.1098, 0.8157, 0.4235, 0.6353, 0.2549, 0.5137, 0.5333])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_video.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
|
||||
@@ -140,11 +140,15 @@ class Cosmos2TextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
generated_image = image[0]
|
||||
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
expected_video = torch.randn(3, 32, 32)
|
||||
max_diff = np.abs(generated_image - expected_video).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.4784, 0.4784, 0.4784, 0.4784, 0.4784, 0.4902, 0.4588, 0.5333])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_image.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||
|
||||
def test_callback_inputs(self):
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
|
||||
@@ -147,11 +147,15 @@ class Cosmos2VideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
|
||||
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
|
||||
expected_video = torch.randn(9, 3, 32, 32)
|
||||
max_diff = np.abs(generated_video - expected_video).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.5098, 0.5137, 0.5176, 0.5098, 0.5255, 0.5412, 0.5098, 0.5059])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_video.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
|
||||
@@ -159,11 +159,15 @@ class CosmosVideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
video = pipe(**inputs).frames
|
||||
generated_video = video[0]
|
||||
|
||||
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
|
||||
expected_video = torch.randn(9, 3, 32, 32)
|
||||
max_diff = np.abs(generated_video - expected_video).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([0.0, 0.8275, 0.7529, 0.7294, 0.0, 0.6, 1.0, 0.3804, 0.6667, 0.0863, 0.8784, 0.5922, 0.6627, 0.2784, 0.5725, 0.7765])
|
||||
# fmt: on
|
||||
|
||||
generated_slice = generated_video.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||
|
||||
def test_components_function(self):
|
||||
init_components = self.get_dummy_components()
|
||||
|
||||
Reference in New Issue
Block a user