mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Fix scheduler type mismatch (#3041)
When doing generation manually and using guidance_scale as a static argument.
This commit is contained in:
@@ -245,6 +245,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
|
||||
context = jnp.concatenate([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
# Ensure model output will be `float32` before going into the scheduler
|
||||
guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
|
||||
|
||||
latents_shape = (
|
||||
batch_size,
|
||||
self.unet.config.in_channels,
|
||||
|
||||
Reference in New Issue
Block a user