From 7bbbfbfd18ed9f5f6ce02bf194382a27150dd4c4 Mon Sep 17 00:00:00 2001
From: Juan Acevedo
Date: Sat, 19 Nov 2022 11:51:52 -0800
Subject: [PATCH] Jax infer support negative prompt (#1337)

* support negative prompts in sd jax pipeline

* pass batched neg_prompt

* only encode when negative prompt is None

Co-authored-by: Juan Acevedo
---
 .../pipeline_flax_stable_diffusion.py        | 60 ++++++++++++++++---
 1 file changed, 52 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 02943997d9..a2f0f73dbf 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -165,6 +165,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         guidance_scale: float = 7.5,
         latents: Optional[jnp.array] = None,
         debug: bool = False,
+        neg_prompt_ids: jnp.array = None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -177,10 +178,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         batch_size = prompt_ids.shape[0]
 
         max_length = prompt_ids.shape[-1]
-        uncond_input = self.tokenizer(
-            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
-        )
-        uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
+
+        if neg_prompt_ids is None:
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
+            ).input_ids
+        else:
+            uncond_input = neg_prompt_ids
+        uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
         context = jnp.concatenate([uncond_embeddings, text_embeddings])
 
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
@@ -251,6 +256,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         return_dict: bool = True,
         jit: bool = False,
         debug: bool = False,
+        neg_prompt_ids: jnp.array = None,
         **kwargs,
     ):
         r"""
@@ -298,11 +304,30 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         """
         if jit:
             images = _p_generate(
-                self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+                self,
+                prompt_ids,
+                params,
+                prng_seed,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                latents,
+                debug,
+                neg_prompt_ids,
             )
         else:
             images = self._generate(
-                prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+                prompt_ids,
+                params,
+                prng_seed,
+                num_inference_steps,
+                height,
+                width,
+                guidance_scale,
+                latents,
+                debug,
+                neg_prompt_ids,
             )
 
         if self.safety_checker is not None:
@@ -333,10 +358,29 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
 # TODO: maybe use a config dict instead of so many static argnums
 @partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
 def _p_generate(
-    pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+    pipe,
+    prompt_ids,
+    params,
+    prng_seed,
+    num_inference_steps,
+    height,
+    width,
+    guidance_scale,
+    latents,
+    debug,
+    neg_prompt_ids,
 ):
     return pipe._generate(
-        prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
+        prompt_ids,
+        params,
+        prng_seed,
+        num_inference_steps,
+        height,
+        width,
+        guidance_scale,
+        latents,
+        debug,
+        neg_prompt_ids,
     )
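
Usage note (not part of the patch): a minimal sketch of how a caller might pass a negative prompt after this change. The checkpoint name, prompt strings, and single-device setup below are illustrative assumptions, not something this diff prescribes; prepare_inputs is the pipeline's existing tokenization helper, and the only new piece is the neg_prompt_ids keyword added above.

# Minimal sketch, assuming the CompVis/stable-diffusion-v1-4 checkpoint,
# a single device, and the default jit=False path. Only the neg_prompt_ids
# keyword argument is new in this patch.
import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16
)

# Tokenize both prompts with the pipeline's helper. When neg_prompt_ids is
# provided, _generate skips tokenizing its own "" placeholders and encodes
# these ids as the unconditional context instead.
prompt_ids = pipeline.prepare_inputs(["a photo of an astronaut riding a horse"])
neg_prompt_ids = pipeline.prepare_inputs(["blurry, low quality, watermark"])

images = pipeline(
    prompt_ids,
    params,
    prng_seed=jax.random.PRNGKey(0),
    num_inference_steps=50,
    neg_prompt_ids=neg_prompt_ids,
).images

Because neg_prompt_ids flows through _p_generate as a regular (non-static) pmap argument, jitted multi-device callers would shard it across devices the same way they already shard prompt_ids.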