Mirror of https://github.com/huggingface/diffusers.git

Flax: avoid recompilation when params change (#1096)

* Do not recompile when guidance_scale changes.

* Remove debug for simplicity.

* make style

* Make guidance_scale an array.

* Make DEBUG a constant to avoid passing it down.

* Add comments for clarification.
Author: Pedro Cuenca
Date: 2022-12-07 14:50:55 +01:00
Committed by: GitHub
Parent: 170ebd288f
Commit: 6a7f1f0965
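
For context on the fix: JAX compiles one program per distinct value of a static argument, so passing guidance_scale as a plain Python float through `static_broadcasted_argnums` forced a retrace whenever its value changed, while a traced array argument is compiled once per shape and dtype. A minimal sketch of the difference, using a toy function and `jax.jit` for brevity (not the pipeline code):

import jax
import jax.numpy as jnp
from functools import partial

# Static argument: the value is baked into the compiled program, so every
# distinct float triggers a fresh trace and compile.
@partial(jax.jit, static_argnums=(1,))
def scale_static(x, scale):
    return x * scale

# Traced argument: `scale` is abstract during tracing, so one compilation
# covers all values with the same shape and dtype.
@jax.jit
def scale_traced(x, scale):
    return x * scale

x = jnp.ones(4)
scale_static(x, 7.5)             # compile #1
scale_static(x, 8.0)             # compile #2 (new static value)
scale_traced(x, jnp.array(7.5))  # compile once
scale_traced(x, jnp.array(8.0))  # cache hit, no recompilation

The commit applies the same idea to `jax.pmap`: guidance_scale moves out of the static argnums and becomes a mapped array input.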


@@ -42,6 +42,9 @@ from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+# Set to True to use python for loop instead of jax.fori_loop for easier debugging
+DEBUG = False
+
 
 class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
     r"""
@@ -187,7 +190,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         width: Optional[int] = None,
         guidance_scale: float = 7.5,
         latents: Optional[jnp.array] = None,
-        debug: bool = False,
         neg_prompt_ids: jnp.array = None,
     ):
         # 0. Default height and width to unet
@@ -260,8 +262,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
 
-        if debug:
+        if DEBUG:
             # run with python for loop
             for i in range(num_inference_steps):
                 latents, scheduler_state = loop_body(i, (latents, scheduler_state))
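
The DEBUG switch matters because `jax.lax.fori_loop` traces the loop body once into the compiled program, so per-step prints and breakpoints never fire; the plain Python loop runs eagerly and is easy to step through. A minimal sketch of the pattern, with a stand-in loop body (the pipeline's real `loop_body` runs the UNet and scheduler over a `(latents, scheduler_state)` tuple):

import jax
import jax.numpy as jnp

DEBUG = False  # flip to True for the eager, debuggable path

def loop_body(i, latents):
    # stand-in for one denoising step
    return latents * 0.99

def denoise(latents, num_inference_steps=10):
    if DEBUG:
        # eager Python loop: ordinary debugging works per step
        for i in range(num_inference_steps):
            latents = loop_body(i, latents)
    else:
        # compiled loop: fast, but the body is traced only once
        latents = jax.lax.fori_loop(0, num_inference_steps, loop_body, latents)
    return latents

print(denoise(jnp.ones(3)))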
@@ -283,11 +284,10 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         num_inference_steps: int = 50,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        guidance_scale: float = 7.5,
+        guidance_scale: Union[float, jnp.array] = 7.5,
         latents: jnp.array = None,
         return_dict: bool = True,
         jit: bool = False,
-        debug: bool = False,
         neg_prompt_ids: jnp.array = None,
     ):
         r"""
@@ -334,6 +334,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
 
+        if isinstance(guidance_scale, float):
+            # Convert to a tensor so each device gets a copy. Follow the prompt_ids for
+            # shape information, as they may be sharded (when `jit` is `True`), or not.
+            guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
+            if len(prompt_ids.shape) > 2:
+                # Assume sharded
+                guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])
+
         if jit:
             images = _p_generate(
                 self,
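
To make the new shape logic concrete: with `jit=True` the inputs arrive sharded, so `prompt_ids` carries a leading device axis and `guidance_scale` must mirror its leading dimensions for `pmap` to split it per device. A hypothetical walk-through with invented sizes (8 devices, 1 prompt per device, 77 tokens):

import jax.numpy as jnp

num_devices, batch_per_device, seq_len = 8, 1, 77  # invented for illustration
prompt_ids = jnp.zeros((num_devices, batch_per_device, seq_len), dtype=jnp.int32)

guidance_scale = 7.5
if isinstance(guidance_scale, float):
    # one copy per prompt, following prompt_ids' leading dimension
    guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])  # (8,)
    if len(prompt_ids.shape) > 2:
        # assume sharded: match (num_devices, batch_per_device)
        guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])  # (8, 1)

print(guidance_scale.shape)  # (8, 1): axis 0 is what pmap maps over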
@@ -345,7 +353,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
                 width,
                 guidance_scale,
                 latents,
-                debug,
                 neg_prompt_ids,
             )
         else:
@@ -358,7 +365,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
                 width,
                 guidance_scale,
                 latents,
-                debug,
                 neg_prompt_ids,
             )
@@ -388,8 +394,13 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
         return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
 
 
-# TODO: maybe use a config dict instead of so many static argnums
-@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
+# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation.
+# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
+@partial(
+    jax.pmap,
+    in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),
+    static_broadcasted_argnums=(0, 4, 5, 6),
+)
 def _p_generate(
     pipe,
     prompt_ids,
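
For readers less familiar with `jax.pmap`: `in_axes=0` splits an argument across devices along its first axis, `in_axes=None` broadcasts the same value to every device, and `static_broadcasted_argnums` bakes hashable Python values into the compiled program, so changing them triggers recompilation. A minimal sketch with a toy function (not `_p_generate`), runnable on any device count:

import jax
import jax.numpy as jnp
from functools import partial

# `x` is sharded over devices (in_axes=0), `bias` is broadcast (None), and
# `exponent` is static: each distinct value compiles a separate program.
@partial(jax.pmap, in_axes=(0, None, None), static_broadcasted_argnums=(2,))
def step(x, bias, exponent):
    return x**exponent + bias

n = jax.local_device_count()
x = jnp.arange(float(n * 3)).reshape(n, 3)  # one row per device
out = step(x, jnp.float32(1.0), 2)
print(out.shape)  # (n, 3)

As in the decorator above, the `in_axes` entries at the static positions are `None`.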
@@ -400,7 +411,6 @@ def _p_generate(
     width,
     guidance_scale,
     latents,
-    debug,
     neg_prompt_ids,
 ):
     return pipe._generate(
@@ -412,7 +422,6 @@ def _p_generate(
         width,
         guidance_scale,
         latents,
-        debug,
         neg_prompt_ids,
     )