diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 1cc1d94414..8e53c3239a 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -14,26 +14,43 @@ import warnings from dataclasses import dataclass -import torch - from ..utils import BaseOutput +from ..utils import is_torch_available, is_flax_available SCHEDULER_CONFIG_NAME = "scheduler_config.json" -@dataclass -class SchedulerOutput(BaseOutput): - """ - Base class for the scheduler's step function output. +if is_torch_available(): + import torch - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - """ + @dataclass + class SchedulerOutput(BaseOutput): + """ + Base class for the scheduler's step function output. - prev_sample: torch.FloatTensor + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + +if is_flax_available(): + import jax.numpy as jnp + + class SchedulerOutput(BaseOutput): + """ + Base class for the scheduler's step function output. + + Args: + prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: jnp.ndarray class SchedulerMixin: