mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Update Flax TPU tests (#3069)
Update Flax TPU tests. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -78,11 +78,10 @@ class FlaxPipelineTests(unittest.TestCase):
|
||||
|
||||
assert images.shape == (num_samples, 1, 64, 64, 3)
|
||||
if jax.device_count() == 8:
|
||||
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 3.1111548) < 1e-3
|
||||
assert np.abs(np.abs(images, dtype=np.float32).sum() - 199746.95) < 5e-1
|
||||
assert np.abs(np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.1514745) < 1e-3
|
||||
assert np.abs(np.abs(images, dtype=np.float32).sum() - 49947.875) < 5e-1
|
||||
|
||||
images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
|
||||
|
||||
assert len(images_pil) == num_samples
|
||||
|
||||
def test_stable_diffusion_v1_4(self):
|
||||
@@ -140,8 +139,8 @@ class FlaxPipelineTests(unittest.TestCase):
|
||||
|
||||
assert images.shape == (num_samples, 1, 512, 512, 3)
|
||||
if jax.device_count() == 8:
|
||||
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
|
||||
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
|
||||
|
||||
def test_stable_diffusion_v1_4_bfloat_16_with_safety(self):
|
||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
@@ -169,8 +168,8 @@ class FlaxPipelineTests(unittest.TestCase):
|
||||
|
||||
assert images.shape == (num_samples, 1, 512, 512, 3)
|
||||
if jax.device_count() == 8:
|
||||
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 5e-1
|
||||
assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.04003906)) < 1e-3
|
||||
assert np.abs((np.abs(images, dtype=np.float32).sum() - 2373516.75)) < 5e-1
|
||||
|
||||
def test_stable_diffusion_v1_4_bfloat_16_ddim(self):
|
||||
scheduler = FlaxDDIMScheduler(
|
||||
|
||||
Reference in New Issue
Block a user