1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Fix invocation of some slow Flax tests (#3058)

* Fix invocation of some slow tests.

We use __call__ rather than pmapping the generation function ourselves
because the number of static arguments is different now.

* style
This commit is contained in:
Pedro Cuenca
2023-04-11 23:21:25 +02:00
committed by Daniel Gu
parent 7e6ab07707
commit d8eedb4787

View File

@@ -28,7 +28,6 @@ if is_flax_available():
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from jax import pmap
from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
@@ -70,14 +69,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 64, 64, 3)
if jax.device_count() == 8:
@@ -105,14 +102,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
@@ -136,14 +131,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8:
@@ -211,14 +204,12 @@ class FlaxPipelineTests(unittest.TestCase):
prompt = num_samples * [prompt]
prompt_ids = pipeline.prepare_inputs(prompt)
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, num_samples)
prompt_ids = shard(prompt_ids)
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
assert images.shape == (num_samples, 1, 512, 512, 3)
if jax.device_count() == 8: