mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

move expand_to_shape

Nathan Lambert
2022-11-23 11:59:15 -08:00
parent 66951ec084
commit b70f6cd5e0
3 changed files with 13 additions and 24 deletions


@@ -24,18 +24,7 @@ import torch
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
-from .scheduling_utils import SchedulerMixin
-def expand_to_shape(input, timesteps, shape, device):
-    """
-    Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
-    nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
-    """
-    out = torch.gather(input.to(device), 0, timesteps.to(device))
-    reshape = [shape[0]] + [1] * (len(shape) - 1)
-    out = out.reshape(*reshape)
-    return out
+from .scheduling_utils import SchedulerMixin, expand_to_shape
 @dataclass


@@ -23,18 +23,7 @@ import torch
 from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
-from .scheduling_utils import SchedulerMixin
-def expand_to_shape(input, timesteps, shape, device):
-    """
-    Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
-    nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
-    """
-    out = torch.gather(input.to(device), 0, timesteps.to(device))
-    reshape = [shape[0]] + [1] * (len(shape) - 1)
-    out = out.reshape(*reshape)
-    return out
+from .scheduling_utils import SchedulerMixin, expand_to_shape
 @dataclass


@@ -152,3 +152,14 @@ class SchedulerMixin:
             getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
         ]
         return compatible_classes
+
+
+def expand_to_shape(input, timesteps, shape, device):
+    """
+    Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
+    nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
+    """
+    out = torch.gather(input.to(device), 0, timesteps.to(device))
+    reshape = [shape[0]] + [1] * (len(shape) - 1)
+    out = out.reshape(*reshape)
+    return out
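
For context, a minimal usage sketch of the relocated helper (not part of the commit): it assumes the post-commit module path diffusers.schedulers.scheduling_utils, and the schedule values, tensor shapes, and timestep indices below are made up for illustration.

import torch

# Assumed import path after this commit; the helper now lives next to SchedulerMixin.
from diffusers.schedulers.scheduling_utils import expand_to_shape

# Toy schedule of 1000 per-timestep coefficients and a batch of 4 latents (made-up shapes).
alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)
sample = torch.randn(4, 3, 64, 64)
timesteps = torch.tensor([10, 200, 500, 999])  # one diffusion step per batch element

# Gathers one coefficient per sample's timestep and reshapes the result to (4, 1, 1, 1),
# so it broadcasts against the (4, 3, 64, 64) sample tensor.
alpha_t = expand_to_shape(alphas_cumprod, timesteps, sample.shape, sample.device)
noisy = alpha_t.sqrt() * sample + (1 - alpha_t).sqrt() * torch.randn_like(sample)
print(alpha_t.shape)  # torch.Size([4, 1, 1, 1])

Because each batch element carries its own timestep, a scheduler can process several diffusion steps in one vectorized call instead of looping over them.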