1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Define SchedulerOutput to use torch or flax arrays.

This commit is contained in:
Pedro Cuenca
2022-09-30 15:00:50 +02:00
parent 2b24dba599
commit f653140134

View File

@@ -14,26 +14,43 @@
import warnings
from dataclasses import dataclass
import torch
from ..utils import BaseOutput
from ..utils import is_torch_available, is_flax_available
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@dataclass
class SchedulerOutput(BaseOutput):
"""
Base class for the scheduler's step function output.
if is_torch_available():
import torch
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
@dataclass
class SchedulerOutput(BaseOutput):
"""
Base class for the scheduler's step function output.
prev_sample: torch.FloatTensor
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
if is_flax_available():
import jax.numpy as jnp
class SchedulerOutput(BaseOutput):
"""
Base class for the scheduler's step function output.
Args:
prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: jnp.ndarray
class SchedulerMixin: