mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Add slow test on full checkpoint for joint mode and correct expected image slices/text prefixes.
This commit is contained in:
@@ -422,6 +422,29 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_unidiffuser_default_joint(self):
|
||||
pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-diffusers")
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
# inputs = self.get_dummy_inputs(device)
|
||||
inputs = self.get_inputs()
|
||||
# Delete prompt and image for joint inference.
|
||||
del inputs["prompt"]
|
||||
del inputs["image"]
|
||||
sample = pipe(**inputs)
|
||||
image = sample.images
|
||||
text = sample.text
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_img_slice = np.array([0.8887, 0.8926, 0.8672, 0.8984, 0.8867, 0.8564, 0.9043, 0.8887, 0.8657])
|
||||
assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
|
||||
|
||||
expected_text_prefix = "Pink pink "
|
||||
assert text[0][:10] == expected_text_prefix
|
||||
|
||||
def test_unidiffuser_default_text2img(self):
|
||||
pipe = UniDiffuserPipeline.from_pretrained("dg845/unidiffuser-diffusers")
|
||||
@@ -436,8 +459,7 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
# TODO: get correct image slice
|
||||
expected_slice = np.array([0.3965, 0.4568, 0.4495, 0.4590, 0.4465, 0.4690, 0.5454, 0.5093, 0.4321])
|
||||
expected_slice = np.array([0.4702, 0.4666, 0.4446, 0.4829, 0.4468, 0.4565, 0.4663, 0.4956, 0.4277])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
def test_unidiffuser_default_img2text(self):
|
||||
@@ -451,6 +473,5 @@ class UniDiffuserPipelineSlowTests(unittest.TestCase):
|
||||
sample = pipe(**inputs)
|
||||
text = sample.images
|
||||
|
||||
# TODO: get correct text prefix
|
||||
expected_text_prefix = " no no no "
|
||||
expected_text_prefix = "Astronaut "
|
||||
assert text[0][:10] == expected_text_prefix
|
||||
|
||||
Reference in New Issue
Block a user