1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Fix scheduler type mismatch (#3041)

When doing generation manually and using guidance_scale as a static
argument.
This commit is contained in:
Pedro Cuenca
2023-04-11 23:20:35 +02:00
committed by Daniel Gu
parent 8d362f2cf5
commit 7e6ab07707

View File

@@ -245,6 +245,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
# Ensure model output will be `float32` before going into the scheduler
guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
latents_shape = (
batch_size,
self.unet.config.in_channels,