diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 23148dcfe2..912a4381d0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -42,6 +42,9 @@ from .safety_checker_flax import FlaxStableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): r""" @@ -187,7 +190,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): width: Optional[int] = None, guidance_scale: float = 7.5, latents: Optional[jnp.array] = None, - debug: bool = False, neg_prompt_ids: jnp.array = None, ): # 0. Default height and width to unet @@ -260,8 +262,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma - - if debug: + if DEBUG: # run with python for loop for i in range(num_inference_steps): latents, scheduler_state = loop_body(i, (latents, scheduler_state)) @@ -283,11 +284,10 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, - guidance_scale: float = 7.5, + guidance_scale: Union[float, jnp.array] = 7.5, latents: jnp.array = None, return_dict: bool = True, jit: bool = False, - debug: bool = False, neg_prompt_ids: jnp.array = None, ): r""" @@ -334,6 +334,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2]) + if jit: images = _p_generate( self, @@ -345,7 +353,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): width, guidance_scale, latents, - debug, neg_prompt_ids, ) else: @@ -358,7 +365,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): width, guidance_scale, latents, - debug, neg_prompt_ids, ) @@ -388,8 +394,13 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -# TODO: maybe use a config dict instead of so many static argnums -@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9)) +# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 4, 5, 6), +) def _p_generate( pipe, prompt_ids, @@ -400,7 +411,6 @@ def _p_generate( width, guidance_scale, latents, - debug, neg_prompt_ids, ): return pipe._generate( @@ -412,7 +422,6 @@ def _p_generate( width, guidance_scale, latents, - debug, neg_prompt_ids, )