diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index b4a70052b0..9502f69953 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -102,9 +102,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5, ) - def scale_model_input( - self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int - ) -> jnp.ndarray: + def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray: """ Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm. @@ -119,7 +117,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): Returns: `jnp.ndarray`: scaled input sample """ - step_index, = jnp.where(scheduler_state.timesteps == timestep, size=1) + (step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1) sigma = scheduler_state.sigmas[step_index] sample = sample / ((sigma**2 + 1) ** 0.5) return sample