Docs: short section on changing the scheduler in Flax (#2181)
* Short doc on changing the scheduler in Flax.

* Apply fix from @patil-suraj

Co-authored-by: Suraj Patil <surajp815@gmail.com>

---------

Co-authored-by: Suraj Patil <surajp815@gmail.com>
@@ -176,6 +176,7 @@ image
<br>
</p>

If you are a JAX/Flax user, please check [this section](#changing-the-scheduler-in-flax) instead.

## Compare schedulers
@@ -260,3 +261,54 @@ image

As you can see, most images look very similar and are arguably of very similar quality. Which scheduler to choose often depends on the specific use case, so a good approach is to run several different schedulers and compare the results.
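
A minimal sketch of such a comparison (assuming a PyTorch setup with a CUDA device; the two scheduler classes are just example choices and the seed is arbitrary):

```python
import torch
from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultistepScheduler

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
results = {}

for scheduler_cls in [EulerDiscreteScheduler, DPMSolverMultistepScheduler]:
    # Swap the scheduler in place, reusing the current scheduler's config
    pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
    # Reseed so every scheduler starts from the same initial latents
    generator = torch.Generator(device="cuda").manual_seed(8)
    results[scheduler_cls.__name__] = pipeline(prompt, generator=generator, num_inference_steps=25).images[0]
```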

## Changing the Scheduler in Flax

If you are a JAX/Flax user, you can also change the default pipeline scheduler. This is a complete example of how to run inference using the Flax Stable Diffusion pipeline and the super-fast [DPM-Solver++ scheduler](../api/schedulers/multistep_dpm_solver):

```python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler

model_id = "runwayml/stable-diffusion-v1-5"

# Load the DPM-Solver++ scheduler (and its state) from the repo's "scheduler" subfolder
scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler"
)

# Create the pipeline with the new scheduler and keep track of its state in `params`
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    revision="bf16",
    dtype=jax.numpy.bfloat16,
)
params["scheduler"] = scheduler_state

# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
prompt = "a photo of an astronaut riding a horse on mars"
num_samples = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 25

# Shard inputs and rng across devices
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
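
`numpy_to_pil` returns a list of PIL images (one per generated sample), so you can, for example, save them to disk; the file names below are just an illustration:

```python
# Save each generated sample (illustrative file names)
for i, image in enumerate(images):
    image.save(f"astronaut_{i}.png")
```
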
<Tip warning={true}>

The following Flax schedulers are _not yet compatible_ with the Flax Stable Diffusion Pipeline:

- `FlaxLMSDiscreteScheduler`
- `FlaxDDPMScheduler`

</Tip>
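
Any other compatible Flax scheduler can be swapped in with the same pattern. As an illustration only (reusing `model_id` and the imports from the example above), this sketch loads `FlaxDDIMScheduler` instead:

```python
from diffusers import FlaxDDIMScheduler

# Load an alternative scheduler and its state from the same repository
ddim_scheduler, ddim_state = FlaxDDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

# Re-create the pipeline with the new scheduler and track its state in `params`
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=ddim_scheduler,
    revision="bf16",
    dtype=jax.numpy.bfloat16,
)
params["scheduler"] = ddim_state
```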