diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py index 281419469d..9c0585d501 100644 --- a/tests/pipelines/unidiffuser/test_unidiffuser.py +++ b/tests/pipelines/unidiffuser/test_unidiffuser.py @@ -422,6 +422,29 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase): "output_type": "numpy", } return inputs + + def test_unidiffuser_default_joint(self): + pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-diffusers") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_inputs() + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + sample = pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.8887, 0.8926, 0.8672, 0.8984, 0.8867, 0.8564, 0.9043, 0.8887, 0.8657]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + expected_text_prefix = "Pink pink " + assert text[0][:10] == expected_text_prefix def test_unidiffuser_default_text2img(self): pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-diffusers") @@ -436,8 +459,7 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase): assert image.shape == (1, 512, 512, 3) image_slice = image[0, -3:, -3:, -1] - # TODO: get correct image slice - expected_slice = np.array([0.3965, 0.4568, 0.4495, 0.4590, 0.4465, 0.4690, 0.5454, 0.5093, 0.4321]) + expected_slice = np.array([0.4702, 0.4666, 0.4446, 0.4829, 0.4468, 0.4565, 0.4663, 0.4956, 0.4277]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_unidiffuser_default_img2text(self): @@ -451,6 +473,5 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase): sample = pipe(**inputs) text = sample.images - # TODO: get correct text prefix - expected_text_prefix = " no no no " + expected_text_prefix = "Astronaut " assert text[0][:10] == expected_text_prefix