diff --git a/docs/source/en/using-diffusers/schedulers.mdx b/docs/source/en/using-diffusers/schedulers.mdx
index 21250b0550..caa80675a0 100644
--- a/docs/source/en/using-diffusers/schedulers.mdx
+++ b/docs/source/en/using-diffusers/schedulers.mdx
@@ -176,6 +176,7 @@ image
+If you are a JAX/Flax user, please check [this section](#changing-the-scheduler-in-flax) instead.
## Compare schedulers
@@ -260,3 +261,54 @@ image
As you can see most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different
schedulers to compare results.
+
+## Changing the Scheduler in Flax
+
+If you are a JAX/Flax user, you can also change the default pipeline scheduler. This is a complete example of how to run inference using the Flax Stable Diffusion pipeline and the super-fast [DDPM-Solver++ scheduler](../api/schedulers/multistep_dpm_solver):
+
+```Python
+import jax
+import numpy as np
+from flax.jax_utils import replicate
+from flax.training.common_utils import shard
+
+from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler
+
+model_id = "runwayml/stable-diffusion-v1-5"
+scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
+ model_id,
+ subfolder="scheduler"
+)
+pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
+ model_id,
+ scheduler=scheduler,
+ revision="bf16",
+ dtype=jax.numpy.bfloat16,
+)
+params["scheduler"] = scheduler_state
+
+# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
+prompt = "a photo of an astronaut riding a horse on mars"
+num_samples = jax.device_count()
+prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)
+
+prng_seed = jax.random.PRNGKey(0)
+num_inference_steps = 25
+
+# shard inputs and rng
+params = replicate(params)
+prng_seed = jax.random.split(prng_seed, jax.device_count())
+prompt_ids = shard(prompt_ids)
+
+images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
+images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+```
+
+
+
+The following Flax schedulers are _not yet compatible_ with the Flax Stable Diffusion Pipeline:
+
+- `FlaxLMSDiscreteScheduler`
+- `FlaxDDPMScheduler`
+
+