1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Update Flax TPU tests (#3069)

Update Flax TPU tests.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Pedro Cuenca
2023-04-12 15:17:36 +02:00
committed by Daniel Gu
parent 870169e08f
commit cbb18d8ba9

View File

@@ -78,11 +78,10 @@ class FlaxPipelineTests(unittest.TestCase):
assert images.shape == (num_samples, 1, 64, 64, 3)
if jax.device_count() == 8:
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 3.1111548) < 1e-3
assert np.abs(np.abs(images, dtype=np.float32).sum() - 199746.95) < 5e-1
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3
assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
assert len(images_pil) == num_samples
def test_stable_diffusion_v1_4(self):
@@ -140,8 +139,8 @@ class FlaxPipelineTests(unittest.TestCase):
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
@@ -169,8 +168,8 @@ class FlaxPipelineTests(unittest.TestCase):
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
scheduler = FlaxDDIMScheduler(