From 7e6ab07707cead830fd9a85e45d3e3e339acf6eb Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 11 Apr 2023 23:20:35 +0200 Subject: [PATCH] Fix scheduler type mismatch (#3041) When doing generation manually and using guidance_scale as a static argument. --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index c0c2ee8b8a..3b4f77029c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -245,6 +245,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + # Ensure model output will be `float32` before going into the scheduler + guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32) + latents_shape = ( batch_size, self.unet.config.in_channels,