From e3095c5f475d6bfa0a02926cd2397d44d57f44fa Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 11 Apr 2023 23:21:25 +0200 Subject: [PATCH] Fix invocation of some slow Flax tests (#3058) * Fix invocation of some slow tests. We use __call__ rather than pmapping the generation function ourselves because the number of static arguments is different now. * style --- tests/test_pipelines_flax.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index a461930f3a..aab2eb9a07 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -28,7 +28,6 @@ if is_flax_available(): import jax.numpy as jnp from flax.jax_utils import replicate from flax.training.common_utils import shard - from jax import pmap from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline @@ -70,14 +69,12 @@ class FlaxPipelineTests(unittest.TestCase): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 64, 64, 3) if jax.device_count() == 8: @@ -105,14 +102,12 @@ class FlaxPipelineTests(unittest.TestCase): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 512, 512, 3) if jax.device_count() == 8: @@ -136,14 +131,12 @@ class FlaxPipelineTests(unittest.TestCase): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 512, 512, 3) if jax.device_count() == 8: @@ -211,14 +204,12 @@ class FlaxPipelineTests(unittest.TestCase): prompt = num_samples * [prompt] prompt_ids = pipeline.prepare_inputs(prompt) - p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) - # shard inputs and rng params = replicate(params) prng_seed = jax.random.split(prng_seed, num_samples) prompt_ids = shard(prompt_ids) - images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images assert images.shape == (num_samples, 1, 512, 512, 3) if jax.device_count() == 8: