mirror of https://github.com/huggingface/diffusers.git
[Pytorch] pytorch only timesteps (#724)
* pytorch timesteps
* style
* get rid of if-else
* fix test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -36,7 +36,7 @@ This allows for rapid experimentation and cleaner abstractions in the code, wher
 To this end, the design of schedulers is such that:

 - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Numpy support currently exists).
+- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists).

 ## API
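A minimal sketch of what "interchangeably" means in practice, assuming a standard diffusers install; the model id, prompt, and scheduler hyperparameters below are illustrative, not part of this commit:

```python
# Hedged sketch: the same pipeline can run with a different scheduler
# to trade speed against generation quality.
from diffusers import DDIMScheduler, StableDiffusionPipeline

# Illustrative hyperparameters and model id only.
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler)
image = pipe("an astronaut riding a horse", num_inference_steps=50).images[0]
```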
@@ -278,11 +278,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)

         # Some schedulers like PNDM have timesteps as arrays
-        # It's more optimzed to move all timesteps to correct device beforehand
-        if torch.is_tensor(self.scheduler.timesteps):
-            timesteps_tensor = self.scheduler.timesteps.to(self.device)
-        else:
-            timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps_tensor = self.scheduler.timesteps.to(self.device)

         # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
         if isinstance(self.scheduler, LMSDiscreteScheduler):
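The simplification above works because `set_timesteps` now always yields a `torch.Tensor`. A minimal sketch of that invariant, assuming a recent diffusers install (scheduler choice and step count are illustrative):

```python
# Sketch: after this change every scheduler exposes `timesteps` as a torch tensor,
# so a single `.to(device)` replaces the old tensor/ndarray branch.
import torch
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)

assert torch.is_tensor(scheduler.timesteps)  # no numpy fallback needed anymore
timesteps = scheduler.timesteps.to("cuda" if torch.cuda.is_available() else "cpu")
```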
@@ -304,7 +304,10 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         latents = init_latents

         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)

         for i, t in enumerate(self.progress_bar(timesteps)):
             t_index = t_start + i
@@ -342,7 +342,10 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         latents = init_latents

         t_start = max(num_inference_steps - init_timestep + offset, 0)
-        timesteps = self.scheduler.timesteps[t_start:]
+
+        # Some schedulers like PNDM have timesteps as arrays
+        # It's more optimized to move all timesteps to correct device beforehand
+        timesteps = self.scheduler.timesteps[t_start:].to(self.device)

         for i, t in tqdm(enumerate(timesteps)):
             t_index = t_start + i
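For context, `t_start` simply drops the earliest timesteps so that denoising starts from the partially noised init image rather than from pure noise. A small numeric sketch with hypothetical values:

```python
# Hypothetical numbers: 50 inference steps, init_timestep 40, offset 1.
num_inference_steps, init_timestep, offset = 50, 40, 1
t_start = max(num_inference_steps - init_timestep + offset, 0)  # -> 11
# timesteps = scheduler.timesteps[t_start:].to(device)
# i.e. only the tail of the schedule (39 of 50 entries, assuming one timestep
# per inference step) is actually run on the init latents.
```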
@@ -2,7 +2,7 @@
 
 - Schedulers are the algorithms to use diffusion models in inference as well as for training. They include the noise schedules and define algorithm-specific diffusion steps.
 - Schedulers can be used interchangeable between diffusion models in inference to find the preferred trade-off between speed and generation quality.
-- Schedulers are available in numpy, but can easily be transformed into PyTorch.
+- Schedulers are available in PyTorch and Jax.

 ## API
@@ -154,7 +154,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -166,7 +166,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         return variance

-    def set_timesteps(self, num_inference_steps: int, **kwargs):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -183,7 +183,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         step_ratio = self.config.num_train_timesteps // self.num_inference_steps
         # creates integer timesteps by multiplying by ratio
         # casting to int to avoid issues when num_inference_step is power of 3
-        self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1]
+        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
         self.timesteps += offset

     def step(
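A minimal sketch of the new `device` argument added above (step count and device string are illustrative):

```python
# Sketch: timesteps can now be placed directly on the target device.
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler.set_timesteps(num_inference_steps=50, device=device)

print(scheduler.timesteps.device)  # matches `device`
print(scheduler.timesteps[:3])     # e.g. tensor([980, 960, 940]) with default config
```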
@@ -142,11 +142,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

         self.variance_type = variance_type

-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -156,9 +156,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         """
         num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(
+        timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
-        )[::-1]
+        )[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)

     def _get_variance(self, t, predicted_variance=None, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
@@ -97,10 +97,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 
         # setable values
         self.num_inference_steps: int = None
-        self.timesteps: np.ndarray = None
+        self.timesteps: np.IntTensor = None
         self.schedule: torch.FloatTensor = None  # sigma(t_i)

-    def set_timesteps(self, num_inference_steps: int):
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -110,7 +110,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 
         """
         self.num_inference_steps = num_inference_steps
-        self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps).to(device)
         schedule = [
             (
                 self.config.sigma_max**2
@@ -118,7 +119,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
             )
             for i in self.timesteps
         ]
-        self.schedule = torch.tensor(schedule, dtype=torch.float32)
+        self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)

     def add_noise_to_input(
         self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
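A small sketch of the effect for the Karras VE scheduler, assuming default constructor settings (step count illustrative): both the timesteps and the sigma schedule now land on the requested device.

```python
# Sketch: timesteps and the sigma(t_i) schedule end up on the same device.
import torch
from diffusers import KarrasVeScheduler

scheduler = KarrasVeScheduler()  # default sigma_min/sigma_max, illustrative only
device = "cuda" if torch.cuda.is_available() else "cpu"
scheduler.set_timesteps(num_inference_steps=30, device=device)

assert scheduler.timesteps.device == scheduler.schedule.device
```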
@@ -147,7 +147,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.plms_timesteps = None
         self.timesteps = None

-    def set_timesteps(self, num_inference_steps: int, **kwargs) -> torch.FloatTensor:
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None, **kwargs):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -184,7 +184,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             ::-1
         ].copy()  # we copy to avoid having negative strides which are not supported by torch.from_numpy

-        self.timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+        self.timesteps = torch.from_numpy(timesteps).to(device)

         self.ets = []
         self.counter = 0
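For context, the PNDM timesteps are the concatenation of the Runge-Kutta warm-up steps and the plain linear multistep steps, which is why repeated values appear in the test expectation near the end of this diff. A hedged sketch (the `steps_offset=1` config mirrors that test):

```python
# Sketch: the first entries come from the PRK warm-up and therefore repeat,
# the remainder are PLMS steps; everything is now an int64 torch tensor.
import torch
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(num_train_timesteps=1000, steps_offset=1)
scheduler.set_timesteps(10)

print(scheduler.timesteps.dtype)  # torch.int64
print(scheduler.timesteps[:6])    # e.g. tensor([901, 851, 851, 801, 801, 751])
```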
@@ -89,7 +89,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
 
         self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

-    def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
+    def set_timesteps(
+        self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
+    ):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -101,7 +103,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         """
         sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps

-        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps)
+        self.timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)

     def set_sigmas(
         self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
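Unlike the discrete schedulers, the score SDE schedulers use continuous timesteps, so `torch.linspace` can create them directly on the target device. A standalone sketch with illustrative values:

```python
# Sketch: continuous timesteps from 1.0 down to sampling_eps, created on the device.
import torch

num_inference_steps, sampling_eps = 5, 1e-5
device = "cuda" if torch.cuda.is_available() else "cpu"
timesteps = torch.linspace(1, sampling_eps, num_inference_steps, device=device)
print(timesteps)  # 5 evenly spaced values from 1.0 down to sampling_eps
```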
@@ -14,9 +14,8 @@
 
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

-# TODO(Patrick, Anton, Suraj) - make scheduler framework independent and clean-up a bit
-
 import math
+from typing import Union

 import torch
@@ -52,8 +51,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
         self.discrete_sigmas = None
         self.timesteps = None

-    def set_timesteps(self, num_inference_steps):
-        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+    def set_timesteps(self, num_inference_steps, device: Union[str, torch.device] = None):
+        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps, device=device)

     def step_pred(self, score, x, t, generator=None):
         if self.timesteps is None:
@@ -354,7 +354,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(5)
-        assert np.equal(scheduler.timesteps, np.array([801, 601, 401, 201, 1])).all()
+        assert torch.equal(scheduler.timesteps, torch.LongTensor([801, 601, 401, 201, 1]))

     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
@@ -568,10 +568,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         scheduler_config = self.get_scheduler_config(steps_offset=1)
         scheduler = scheduler_class(**scheduler_config)
         scheduler.set_timesteps(10)
-        assert np.equal(
+        assert torch.equal(
             scheduler.timesteps,
-            np.array([901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]),
-        ).all()
+            torch.LongTensor(
+                [901, 851, 851, 801, 801, 751, 751, 701, 701, 651, 651, 601, 601, 501, 401, 301, 201, 101, 1]
+            ),
+        )

     def test_betas(self):
         for beta_start, beta_end in zip([0.0001, 0.001], [0.002, 0.02]):