1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-26 05:21:40 +03:00

Improve docstrings and type hints in scheduling_ddim_inverse.py (#13020)

docs: improve docstring scheduling_ddim_inverse.py
This commit is contained in:
David El Malih
2026-01-23 00:42:45 +01:00
committed by GitHub
parent 1d32b19ad4
commit d4f97d1921

View File

@@ -99,7 +99,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
@@ -187,14 +187,14 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
clip_sample_range: float = 1.0,
timestep_spacing: str = "leading",
timestep_spacing: Literal["leading", "trailing"] = "leading",
rescale_betas_zero_snr: bool = False,
**kwargs,
):
@@ -210,7 +210,15 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -256,7 +264,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int,
device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -308,20 +320,10 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or
`tuple`.
@@ -335,7 +337,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
# 1. get previous step value (=t+1)
prev_timestep = timestep
timestep = min(
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
timestep - self.config.num_train_timesteps // self.num_inference_steps,
self.config.num_train_timesteps - 1,
)
# 2. compute alphas, betas
@@ -378,5 +381,5 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
return (prev_sample, pred_original_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps