mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add fast test for checking UniDiffuser-v1 sampling.
This commit is contained in:
@@ -294,6 +294,34 @@ class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
expected_text_prefix = " no no no "
|
||||
assert text[0][:10] == expected_text_prefix
|
||||
|
||||
def test_unidiffuser_default_joint_v1(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-test-v1")
|
||||
unidiffuser_pipe = unidiffuser_pipe.to(device)
|
||||
unidiffuser_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Set mode to 'joint'
|
||||
unidiffuser_pipe.set_joint_mode()
|
||||
assert unidiffuser_pipe.mode == "joint"
|
||||
|
||||
# inputs = self.get_dummy_inputs(device)
|
||||
inputs = self.get_dummy_inputs_with_latents(device)
|
||||
# Delete prompt and image for joint inference.
|
||||
del inputs["prompt"]
|
||||
del inputs["image"]
|
||||
inputs["data_type"] = 1
|
||||
sample = unidiffuser_pipe(**inputs)
|
||||
image = sample.images
|
||||
text = sample.text
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_img_slice = np.array([0.5759, 0.6270, 0.6571, 0.4966, 0.4639, 0.5663, 0.5254, 0.5068, 0.5715])
|
||||
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
|
||||
|
||||
expected_text_prefix = " no no no "
|
||||
assert text[0][:10] == expected_text_prefix
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user