mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Define SchedulerOutput to use torch or flax arrays.
This commit is contained in:
@@ -14,26 +14,43 @@
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import BaseOutput
|
||||
from ..utils import is_torch_available, is_flax_available
|
||||
|
||||
|
||||
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Base class for the scheduler's step function output.
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
@dataclass
|
||||
class SchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Base class for the scheduler's step function output.
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
class SchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Base class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: jnp.ndarray
|
||||
|
||||
|
||||
class SchedulerMixin:
|
||||
|
||||
Reference in New Issue
Block a user