mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
move expand_to_shape
This commit is contained in:
@@ -24,18 +24,7 @@ import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
def expand_to_shape(input, timesteps, shape, device):
|
||||
"""
|
||||
Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
|
||||
nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
|
||||
"""
|
||||
out = torch.gather(input.to(device), 0, timesteps.to(device))
|
||||
reshape = [shape[0]] + [1] * (len(shape) - 1)
|
||||
out = out.reshape(*reshape)
|
||||
return out
|
||||
from .scheduling_utils import SchedulerMixin, expand_to_shape
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -23,18 +23,7 @@ import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
def expand_to_shape(input, timesteps, shape, device):
|
||||
"""
|
||||
Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
|
||||
nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
|
||||
"""
|
||||
out = torch.gather(input.to(device), 0, timesteps.to(device))
|
||||
reshape = [shape[0]] + [1] * (len(shape) - 1)
|
||||
out = out.reshape(*reshape)
|
||||
return out
|
||||
from .scheduling_utils import SchedulerMixin, expand_to_shape
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -152,3 +152,14 @@ class SchedulerMixin:
|
||||
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
|
||||
]
|
||||
return compatible_classes
|
||||
|
||||
|
||||
def expand_to_shape(input, timesteps, shape, device):
|
||||
"""
|
||||
Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
|
||||
nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
|
||||
"""
|
||||
out = torch.gather(input.to(device), 0, timesteps.to(device))
|
||||
reshape = [shape[0]] + [1] * (len(shape) - 1)
|
||||
out = out.reshape(*reshape)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user