mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Flax: avoid recompilation when params change (#1096)
* Do not recompile when `guidance_scale` changes.
* Remove the `debug` argument for simplicity.
* make style
* Make `guidance_scale` an array.
* Make `DEBUG` a constant to avoid passing it down.
* Add comments for clarification.
This commit is contained in:
@@ -42,6 +42,9 @@ from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# Set to True to use a Python for loop instead of jax.lax.fori_loop for easier debugging
|
||||
DEBUG = False
|
||||
|
||||
|
||||
class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
r"""
|
||||
@@ -187,7 +190,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
width: Optional[int] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
latents: Optional[jnp.array] = None,
|
||||
debug: bool = False,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
):
|
||||
# 0. Default height and width to unet
|
||||
@@ -260,8 +262,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
|
||||
if debug:
|
||||
if DEBUG:
|
||||
# run with python for loop
|
||||
for i in range(num_inference_steps):
|
||||
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
|
||||
@@ -283,11 +284,10 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
num_inference_steps: int = 50,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
guidance_scale: Union[float, jnp.array] = 7.5,
|
||||
latents: jnp.array = None,
|
||||
return_dict: bool = True,
|
||||
jit: bool = False,
|
||||
debug: bool = False,
|
||||
neg_prompt_ids: jnp.array = None,
|
||||
):
|
||||
r"""
|
||||
@@ -334,6 +334,14 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
if isinstance(guidance_scale, float):
|
||||
# Convert to a tensor so each device gets a copy. Follow the prompt_ids for
|
||||
# shape information, as they may be sharded (when `jit` is `True`), or not.
|
||||
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
|
||||
if len(prompt_ids.shape) > 2:
|
||||
# Assume sharded
|
||||
guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])
|
||||
|
||||
if jit:
|
||||
images = _p_generate(
|
||||
self,
|
||||
@@ -345,7 +353,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
debug,
|
||||
neg_prompt_ids,
|
||||
)
|
||||
else:
|
||||
@@ -358,7 +365,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
debug,
|
||||
neg_prompt_ids,
|
||||
)
|
||||
|
||||
@@ -388,8 +394,13 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
|
||||
# TODO: maybe use a config dict instead of so many static argnums
|
||||
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
|
||||
# Static argnums are pipe, num_inference_steps, height and width. Changing any of them triggers recompilation.
|
||||
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
|
||||
@partial(
|
||||
jax.pmap,
|
||||
in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),
|
||||
static_broadcasted_argnums=(0, 4, 5, 6),
|
||||
)
|
||||
def _p_generate(
|
||||
pipe,
|
||||
prompt_ids,
|
||||
@@ -400,7 +411,6 @@ def _p_generate(
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
debug,
|
||||
neg_prompt_ids,
|
||||
):
|
||||
return pipe._generate(
|
||||
@@ -412,7 +422,6 @@ def _p_generate(
|
||||
width,
|
||||
guidance_scale,
|
||||
latents,
|
||||
debug,
|
||||
neg_prompt_ids,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user