1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Flax DDPM] Make key optional so default pipelines don't fail (#2176)

Make `key` optional so default pipelines don't fail.
This commit is contained in:
Pedro Cuenca
2023-02-04 20:45:20 +01:00
committed by GitHub
parent 9baa29e9c0
commit 154a7865fc

View File

@@ -198,7 +198,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: jax.random.KeyArray,
key: Optional[jax.random.KeyArray] = None,
return_dict: bool = True,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
"""
@@ -221,6 +221,9 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
t = timestep
if key is None:
key = jax.random.PRNGKey(0)
if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
else: