1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Jax infer support negative prompt (#1337)

* support negative prompts in sd jax pipeline

* pass batched neg_prompt

* only encode when negative prompt is None

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
This commit is contained in:
Juan Acevedo
2022-11-19 11:51:52 -08:00
committed by GitHub
parent 30220905c4
commit 7bbbfbfd18

View File

@@ -165,6 +165,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
guidance_scale: float = 7.5,
latents: Optional[jnp.array] = None,
debug: bool = False,
neg_prompt_ids: jnp.array = None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -177,10 +178,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
batch_size = prompt_ids.shape[0]
max_length = prompt_ids.shape[-1]
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
if neg_prompt_ids is None:
uncond_input = self.tokenizer(
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
).input_ids
else:
uncond_input = neg_prompt_ids
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
@@ -251,6 +256,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
return_dict: bool = True,
jit: bool = False,
debug: bool = False,
neg_prompt_ids: jnp.array = None,
**kwargs,
):
r"""
@@ -298,11 +304,30 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
"""
if jit:
images = _p_generate(
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
self,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
)
else:
images = self._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
)
if self.safety_checker is not None:
@@ -333,10 +358,29 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate(
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
pipe,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
):
return pipe._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
debug,
neg_prompt_ids,
)