mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix invocation of some slow Flax tests (#3058)
* Fix invocation of some slow tests. We use __call__ rather than pmapping the generation function ourselves because the number of static arguments is different now. * style
This commit is contained in:
@@ -28,7 +28,6 @@ if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
from flax.jax_utils import replicate
|
||||
from flax.training.common_utils import shard
|
||||
from jax import pmap
|
||||
|
||||
from diffusers import FlaxDDIMScheduler, FlaxDiffusionPipeline, FlaxStableDiffusionPipeline
|
||||
|
||||
@@ -70,14 +69,12 @@ class FlaxPipelineTests(unittest.TestCase):
|
||||
prompt = num_samples * [prompt]
|
||||
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||
|
||||
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
|
||||
|
||||
# shard inputs and rng
|
||||
params = replicate(params)
|
||||
prng_seed = jax.random.split(prng_seed, num_samples)
|
||||
prompt_ids = shard(prompt_ids)
|
||||
|
||||
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
|
||||
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
||||
|
||||
assert images.shape == (num_samples, 1, 64, 64, 3)
|
||||
if jax.device_count() == 8:
|
||||
@@ -105,14 +102,12 @@ class FlaxPipelineTests(unittest.TestCase):
|
||||
prompt = num_samples * [prompt]
|
||||
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||
|
||||
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
|
||||
|
||||
# shard inputs and rng
|
||||
params = replicate(params)
|
||||
prng_seed = jax.random.split(prng_seed, num_samples)
|
||||
prompt_ids = shard(prompt_ids)
|
||||
|
||||
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
|
||||
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
||||
|
||||
assert images.shape == (num_samples, 1, 512, 512, 3)
|
||||
if jax.device_count() == 8:
|
||||
@@ -136,14 +131,12 @@ class FlaxPipelineTests(unittest.TestCase):
|
||||
prompt = num_samples * [prompt]
|
||||
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||
|
||||
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
|
||||
|
||||
# shard inputs and rng
|
||||
params = replicate(params)
|
||||
prng_seed = jax.random.split(prng_seed, num_samples)
|
||||
prompt_ids = shard(prompt_ids)
|
||||
|
||||
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
|
||||
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
||||
|
||||
assert images.shape == (num_samples, 1, 512, 512, 3)
|
||||
if jax.device_count() == 8:
|
||||
@@ -211,14 +204,12 @@ class FlaxPipelineTests(unittest.TestCase):
|
||||
prompt = num_samples * [prompt]
|
||||
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||
|
||||
p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,))
|
||||
|
||||
# shard inputs and rng
|
||||
params = replicate(params)
|
||||
prng_seed = jax.random.split(prng_seed, num_samples)
|
||||
prompt_ids = shard(prompt_ids)
|
||||
|
||||
images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images
|
||||
images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
|
||||
|
||||
assert images.shape == (num_samples, 1, 512, 512, 3)
|
||||
if jax.device_count() == 8:
|
||||
|
||||
Reference in New Issue
Block a user