From 8dd7b0be2c1235e4856c5c2ef48e8db8ffba19dc Mon Sep 17 00:00:00 2001
From: Daniel Gu
Date: Mon, 8 May 2023 17:50:46 -0700
Subject: [PATCH] Add fast test for checking UniDiffuser-v1 sampling.

---
 .../pipelines/unidiffuser/test_unidiffuser.py | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py
index 7b80860839..14ef684156 100644
--- a/tests/pipelines/unidiffuser/test_unidiffuser.py
+++ b/tests/pipelines/unidiffuser/test_unidiffuser.py
@@ -294,6 +294,34 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_text_prefix = " no no no "
         assert text[0][:10] == expected_text_prefix
 
+    def test_unidiffuser_default_joint_v1(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-test-v1")
+        unidiffuser_pipe = unidiffuser_pipe.to(device)
+        unidiffuser_pipe.set_progress_bar_config(disable=None)
+
+        # Set mode to 'joint'
+        unidiffuser_pipe.set_joint_mode()
+        assert unidiffuser_pipe.mode == "joint"
+
+        # inputs = self.get_dummy_inputs(device)
+        inputs = self.get_dummy_inputs_with_latents(device)
+        # Delete prompt and image for joint inference.
+        del inputs["prompt"]
+        del inputs["image"]
+        inputs["data_type"] = 1
+        sample = unidiffuser_pipe(**inputs)
+        image = sample.images
+        text = sample.text
+        assert image.shape == (1, 32, 32, 3)
+
+        image_slice = image[0, -3:, -3:, -1]
+        expected_img_slice = np.array([0.5759, 0.6270, 0.6571, 0.4966, 0.4639, 0.5663, 0.5254, 0.5068, 0.5715])
+        assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
+
+        expected_text_prefix = " no no no "
+        assert text[0][:10] == expected_text_prefix
+
 
 @slow
 @require_torch_gpu
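
For reference, below is a minimal sketch of the joint-mode sampling path this test exercises, using the `UniDiffuserPipeline` API as it appears in the patch (`from_pretrained`, `set_joint_mode`, and the `.images` / `.text` outputs). The full-size checkpoint id, device, dtype, and sampling arguments are illustrative assumptions, not taken from the patch.

```python
import torch
from diffusers import UniDiffuserPipeline

# Illustrative full-size checkpoint; the fast test above instead loads the tiny
# "dg845/unidiffuser-test-v1" checkpoint so it can run quickly on CPU.
pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Joint mode samples an (image, text) pair unconditionally, which is what the
# test above checks after deleting the prompt and image inputs.
pipe.set_joint_mode()
sample = pipe(num_inference_steps=20, guidance_scale=8.0)

image = sample.images[0]  # PIL.Image with the default output type
text = sample.text[0]     # generated caption as a str
image.save("unidiffuser_joint_sample.png")
print(text)
```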