1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

make style

This commit is contained in:
Patrick von Platen
2022-12-02 17:18:50 +00:00
parent 155d272cc1
commit 7222a8eadf

View File

@@ -102,9 +102,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
)
def scale_model_input(
self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int
) -> jnp.ndarray:
def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
"""
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
@@ -119,7 +117,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
Returns:
`jnp.ndarray`: scaled input sample
"""
step_index, = jnp.where(scheduler_state.timesteps == timestep, size=1)
(step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1)
sigma = scheduler_state.sigmas[step_index]
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample