mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Jax infer support negative prompt (#1337)

* Support negative prompts in the Stable Diffusion JAX pipeline.
* Pass the negative prompt batched (`neg_prompt_ids`) through `__call__`.
* Only encode the empty-string unconditional prompt when no negative prompt is given.

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
This commit is contained in:
@@ -165,6 +165,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
guidance_scale: float = 7.5,
|
||||
latents: Optional[jnp.array] = None,
|
||||
debug: bool = False,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
@@ -177,10 +178,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
batch_size = prompt_ids.shape[0]
|
||||
|
||||
max_length = prompt_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
|
||||
|
||||
if neg_prompt_ids is None:
|
||||
uncond_input = self.tokenizer(
|
||||
[""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np"
|
||||
).input_ids
|
||||
else:
|
||||
uncond_input = neg_prompt_ids
|
||||
uncond_embeddings = self.text_encoder(uncond_input, params=params["text_encoder"])[0]
|
||||
context = jnp.concatenate([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
|
||||
@@ -251,6 +256,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
return_dict: bool = True,
|
||||
jit: bool = False,
|
||||
debug: bool = False,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -298,11 +304,30 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
"""
|
||||
if jit:
|
||||
images = _p_generate(
|
||||
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||
self,
|
||||
prompt_ids,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
debug,
|
||||
neg_prompt_ids,
|
||||
)
|
||||
else:
|
||||
images = self._generate(
|
||||
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||
prompt_ids,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
debug,
|
||||
neg_prompt_ids,
|
||||
)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
@@ -333,10 +358,29 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
# TODO: maybe use a config dict instead of so many static argnums
|
||||
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
|
||||
def _p_generate(
|
||||
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||
pipe,
|
||||
prompt_ids,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
debug,
|
||||
neg_prompt_ids,
|
||||
):
|
||||
return pipe._generate(
|
||||
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||
prompt_ids,
|
||||
params,
|
||||
prng_seed,
|
||||
num_inference_steps,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
debug,
|
||||
neg_prompt_ids,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user