From b70f6cd5e0412aeb63b1dafe6b10e87f66be5f17 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 23 Nov 2022 11:59:15 -0800 Subject: [PATCH] move expand_to_shape --- src/diffusers/schedulers/scheduling_ddim.py | 13 +------------ src/diffusers/schedulers/scheduling_ddpm.py | 13 +------------ src/diffusers/schedulers/scheduling_utils.py | 11 +++++++++++ 3 files changed, 13 insertions(+), 24 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 6df8c09051..f94b448603 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -24,18 +24,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput -from .scheduling_utils import SchedulerMixin - - -def expand_to_shape(input, timesteps, shape, device): - """ - Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast - nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once. - """ - out = torch.gather(input.to(device), 0, timesteps.to(device)) - reshape = [shape[0]] + [1] * (len(shape) - 1) - out = out.reshape(*reshape) - return out +from .scheduling_utils import SchedulerMixin, expand_to_shape @dataclass diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 4b0ae8f74a..26ce386f77 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -23,18 +23,7 @@ import torch from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate -from .scheduling_utils import SchedulerMixin - - -def expand_to_shape(input, timesteps, shape, device): - """ - Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast - nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once. - """ - out = torch.gather(input.to(device), 0, timesteps.to(device)) - reshape = [shape[0]] + [1] * (len(shape) - 1) - out = out.reshape(*reshape) - return out +from .scheduling_utils import SchedulerMixin, expand_to_shape @dataclass diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 90ab674e38..973b1298fc 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -152,3 +152,14 @@ class SchedulerMixin: getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) ] return compatible_classes + + +def expand_to_shape(input, timesteps, shape, device): + """ + Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast + nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once. + """ + out = torch.gather(input.to(device), 0, timesteps.to(device)) + reshape = [shape[0]] + [1] * (len(shape) - 1) + out = out.reshape(*reshape) + return out