From 154a7865fc3ade9f47cde5f9fe83dc44d53ccb44 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Sat, 4 Feb 2023 20:45:20 +0100 Subject: [PATCH] [Flax DDPM] Make `key` optional so default pipelines don't fail (#2176) Make `key` optional so default pipelines don't fail. --- src/diffusers/schedulers/scheduling_ddpm_flax.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index faf59b10f3..3179538e83 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -198,7 +198,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - key: jax.random.KeyArray, + key: Optional[jax.random.KeyArray] = None, return_dict: bool = True, ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: """ @@ -221,6 +221,9 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): """ t = timestep + if key is None: + key = jax.random.PRNGKey(0) + if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]: model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1) else: