mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
make style
This commit is contained in:
@@ -102,9 +102,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
|
||||
)
|
||||
|
||||
def scale_model_input(
|
||||
self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int
|
||||
) -> jnp.ndarray:
|
||||
def scale_model_input(self, state: LMSDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
|
||||
"""
|
||||
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
|
||||
|
||||
@@ -119,7 +117,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
Returns:
|
||||
`jnp.ndarray`: scaled input sample
|
||||
"""
|
||||
step_index, = jnp.where(scheduler_state.timesteps == timestep, size=1)
|
||||
(step_index,) = jnp.where(scheduler_state.timesteps == timestep, size=1)
|
||||
sigma = scheduler_state.sigmas[step_index]
|
||||
sample = sample / ((sigma**2 + 1) ** 0.5)
|
||||
return sample
|
||||
|
||||
Reference in New Issue
Block a user