1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[Type hint] scheduling ddim (#343)

* [Type hint] scheduling ddim

* apply suggestions from code review

apply suggestions to also return the return type

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Santiago Víquez
2022-09-04 18:07:54 +02:00
committed by GitHub
parent 5791f4acde
commit 9ea9c6d1c2

View File

@@ -16,7 +16,7 @@
# and https://github.com/hojonathanho/diffusion
import math
from typing import Union
from typing import Optional, Union
import numpy as np
import torch
@@ -52,15 +52,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
clip_sample=True,
set_alpha_to_one=True,
tensor_format="pt",
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
timestep_values: Optional[np.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
tensor_format: str = "pt",
):
if beta_schedule == "linear":
@@ -100,7 +100,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance
def set_timesteps(self, num_inference_steps, offset=0):
def set_timesteps(self, num_inference_steps: int, offset: int = 0):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
@@ -176,7 +176,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return {"prev_sample": prev_sample}
def add_noise(self, original_samples, noise, timesteps):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5