1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add fast test for checking UniDiffuser-v1 sampling.

This commit is contained in:
Daniel Gu
2023-05-08 17:50:46 -07:00
parent 54c495f175
commit 8dd7b0be2c

View File

@@ -294,6 +294,34 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
expected_text_prefix = " no no no "
assert text[0][:10] == expected_text_prefix
def test_unidiffuser_default_joint_v1(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-test-v1")
unidiffuser_pipe = unidiffuser_pipe.to(device)
unidiffuser_pipe.set_progress_bar_config(disable=None)
# Set mode to 'joint'
unidiffuser_pipe.set_joint_mode()
assert unidiffuser_pipe.mode == "joint"
# inputs = self.get_dummy_inputs(device)
inputs = self.get_dummy_inputs_with_latents(device)
# Delete prompt and image for joint inference.
del inputs["prompt"]
del inputs["image"]
inputs["data_type"] = 1
sample = unidiffuser_pipe(**inputs)
image = sample.images
text = sample.text
assert image.shape == (1, 32, 32, 3)
image_slice = image[0, -3:, -3:, -1]
expected_img_slice = np.array([0.5759, 0.6270, 0.6571, 0.4966, 0.4639, 0.5663, 0.5254, 0.5068, 0.5715])
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
expected_text_prefix = " no no no "
assert text[0][:10] == expected_text_prefix
@slow
@require_torch_gpu