mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[Type hint] scheduling ddim (#343)
* [Type hint] scheduling ddim * apply suggestions from code review apply suggestions to also return the return type Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
# and https://github.com/hojonathanho/diffusion
|
||||
|
||||
import math
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -52,15 +52,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.0001,
|
||||
beta_end=0.02,
|
||||
beta_schedule="linear",
|
||||
trained_betas=None,
|
||||
timestep_values=None,
|
||||
clip_sample=True,
|
||||
set_alpha_to_one=True,
|
||||
tensor_format="pt",
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
timestep_values: Optional[np.ndarray] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
tensor_format: str = "pt",
|
||||
):
|
||||
|
||||
if beta_schedule == "linear":
|
||||
@@ -100,7 +100,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def set_timesteps(self, num_inference_steps, offset=0):
|
||||
def set_timesteps(self, num_inference_steps: int, offset: int = 0):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = np.arange(
|
||||
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
|
||||
@@ -176,7 +176,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return {"prev_sample": prev_sample}
|
||||
|
||||
def add_noise(self, original_samples, noise, timesteps):
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: Union[torch.FloatTensor, np.ndarray],
|
||||
noise: Union[torch.FloatTensor, np.ndarray],
|
||||
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||
) -> Union[torch.FloatTensor, np.ndarray]:
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
|
||||
Reference in New Issue
Block a user