From 7222a8eadf41ae968eb833cd3e00efffdc547287 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 2 Dec 2022 17:18:50 +0000 Subject: [PATCH] make style --- src/diffusers/schedulers/scheduling_lms_discrete_flax.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index b4a70052b0..9502f69953 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -102,9 +102,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5, ) - def scale_model_input( - self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int - ) -> jnp.ndarray: + def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray: """ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. @@ -119,7 +117,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): Returns: `jnp.ndarray`: scaled input sample """ - step_index, = jnp.where(scheduler_state.timesteps == timestep, size=1) + (step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1) sigma = scheduler_state.sigmas[step_index] sample = sample / ((sigma**2 + 1) ** 0.5) return sample