
Improve docstrings and type hints in scheduling_consistency_models.py (#12931)

docs: improve docstring scheduling_consistency_models.py
Author: David El Malih (committed by GitHub)
Date: 2026-01-09 18:56:56 +01:00
Parent: 441b69eabf
Commit: d36564f06a


@@ -83,7 +83,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         s_noise: float = 1.0,
         rho: float = 7.0,
         clip_denoised: bool = True,
-    ):
+    ) -> None:
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = sigma_max
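The hunk above only adds a `-> None` return annotation to `__init__`, but the parameters shown are the scheduler's Karras-schedule knobs. A minimal instantiation sketch, assuming the documented defaults of `CMStochasticIterativeScheduler` (the values are illustrative; check your checkpoint's config):

from diffusers import CMStochasticIterativeScheduler

# Sketch: construct the scheduler with Karras-style noise-schedule settings.
scheduler = CMStochasticIterativeScheduler(
    num_train_timesteps=40,  # assumed default for consistency models
    sigma_min=0.002,         # lowest sigma in the schedule
    sigma_max=80.0,          # highest sigma; also init_noise_sigma above
    sigma_data=0.5,          # data distribution std used by the scalings
    rho=7.0,                 # Karras schedule exponent
    s_noise=1.0,             # noise multiplier for stochastic sampling
    clip_denoised=True,
)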
@@ -102,21 +102,29 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
         The index counter for current timestep. It will increase 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index, or `None` if not yet initialized.
         """
         return self._step_index

     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index, or `None` if not yet set.
         """
         return self._begin_index

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
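A short usage sketch of the index API documented in this hunk; the call order mirrors how diffusers pipelines drive the scheduler, and the assertions restate the docstrings (`step_index` stays `None` until the first `step()` call):

scheduler.set_timesteps(num_inference_steps=2)
scheduler.set_begin_index(0)  # pipelines call this before the denoising loop
assert scheduler.begin_index == 0
assert scheduler.step_index is None  # populated on the first step() call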
@@ -151,7 +159,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = True
         return sample

-    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
+    def sigma_to_t(self, sigmas: Union[float, np.ndarray]) -> np.ndarray:
         """
         Gets scaled timesteps from the Karras sigmas for input to the consistency model.
@@ -160,8 +168,8 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
                 A single Karras sigma or an array of Karras sigmas.

         Returns:
-            `float` or `np.ndarray`:
-                A scaled input timestep or scaled input timestep array.
+            `np.ndarray`:
+                A scaled input timestep array.
         """
         if not isinstance(sigmas, np.ndarray):
             sigmas = np.array(sigmas, dtype=np.float64)
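The tightened return annotation reflects the coercion visible above: a scalar sigma is promoted to an array before scaling, so callers always receive `np.ndarray`. A quick sketch of that behavior:

import numpy as np

t = scheduler.sigma_to_t(80.0)  # scalar input is coerced to an array
assert isinstance(t, np.ndarray)
ts = scheduler.sigma_to_t(np.array([0.002, 80.0]))
assert ts.shape == (2,)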
@@ -173,14 +181,14 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
     def set_timesteps(
         self,
         num_inference_steps: Optional[int] = None,
-        device: Union[str, torch.device] = None,
+        device: Optional[Union[str, torch.device]] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the timesteps used for the diffusion chain (to be run before inference).

         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
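A sketch of the two calling conventions this signature documents; `num_inference_steps` and `timesteps` are mutually exclusive, and the custom values here ([22, 0] for two-step sampling) are illustrative, drawn from common consistency-model usage rather than from this diff:

import torch

# Either pass a step count...
scheduler.set_timesteps(num_inference_steps=2, device=torch.device("cpu"))
# ...or explicit custom timesteps (do not combine with num_inference_steps).
scheduler.set_timesteps(timesteps=[22, 0], device=torch.device("cpu"))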
@@ -244,9 +252,19 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Modified _convert_to_karras implementation that takes in ramp as argument
-    def _convert_to_karras(self, ramp):
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _convert_to_karras(self, ramp: np.ndarray) -> np.ndarray:
+        """
+        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            ramp (`np.ndarray`):
+                A ramp array of values between 0 and 1 used to interpolate between sigma_min and sigma_max.
+
+        Returns:
+            `np.ndarray`:
+                The Karras sigma schedule array.
+        """
         sigma_min: float = self.config.sigma_min
         sigma_max: float = self.config.sigma_max
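The interpolation body appears at the top of the next hunk. As a self-contained reference, a NumPy sketch of the rho-schedule it implements, sigma(r) = (sigma_max^(1/rho) + r * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho, with the scheduler's default config values assumed:

import numpy as np

def karras_sigmas(ramp, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Interpolate in sigma^(1/rho) space, then raise back to the rho power.
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

ramp = np.linspace(0, 1, 10)  # 0 maps to sigma_max, 1 maps to sigma_min
sigmas = karras_sigmas(ramp)  # monotonically decreasing schedule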
@@ -256,14 +274,25 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas

-    def get_scalings(self, sigma):
+    def get_scalings(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Computes the scaling factors for the consistency model output.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The current sigma value in the noise schedule.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing `c_skip` (scaling for the input sample) and `c_out` (scaling for the model output).
+        """
         sigma_data = self.config.sigma_data
+
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
         c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
         return c_skip, c_out

-    def get_scalings_for_boundary_condition(self, sigma):
+    def get_scalings_for_boundary_condition(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Gets the scalings used in the consistency model parameterization (from Appendix C of the
         [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
@@ -275,7 +304,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
                 The current sigma in the Karras sigma schedule.

         Returns:
-            `tuple`:
+            `Tuple[torch.Tensor, torch.Tensor]`:
                 A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
                 (which weights the consistency model output) is the second element.
         """
@@ -348,13 +377,13 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         Args:
             model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            timestep (`float`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            return_dict (`bool`, *optional*, defaults to `True`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a
                 [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`.
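A sampling-loop sketch that exercises the `step()` signature documented above, assuming `set_timesteps` has already been called as in the earlier example; `model` is a hypothetical stand-in for a consistency UNet returning `model_output` directly, and the sample shape is illustrative:

import torch

generator = torch.Generator().manual_seed(0)
sample = scheduler.init_noise_sigma * torch.randn(1, 3, 64, 64, generator=generator)
for t in scheduler.timesteps:
    scaled = scheduler.scale_model_input(sample, t)
    model_output = model(scaled, t)  # hypothetical model call
    sample = scheduler.step(model_output, t, sample, generator=generator).prev_sample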
@@ -406,7 +435,10 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         # Noise is not used for onestep sampling.
         if len(self.timesteps) > 1:
             noise = randn_tensor(
-                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+                model_output.shape,
+                dtype=model_output.dtype,
+                device=model_output.device,
+                generator=generator,
             )
         else:
             noise = torch.zeros_like(model_output)
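The reflow above is behavior-neutral; `randn_tensor` is diffusers' generator- and device-aware wrapper around `torch.randn`. A sketch of the same call outside the scheduler (the shape is illustrative, and the import path assumes the current `diffusers.utils.torch_utils` location):

import torch
from diffusers.utils.torch_utils import randn_tensor

gen = torch.Generator(device="cpu").manual_seed(0)
noise = randn_tensor((1, 3, 64, 64), generator=gen, device=torch.device("cpu"), dtype=torch.float32)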
@@ -475,5 +507,12 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         return noisy_samples

-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Returns the number of training timesteps.
+
+        Returns:
+            `int`:
+                The number of training timesteps configured for the scheduler.
+        """
         return self.config.num_train_timesteps
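The `add_noise` lines at the top of this final hunk show the EDM-style forward process, x_noisy = x + sigma * eps. A hedged training-side sketch (shapes and the per-sample timestep choice are illustrative; `set_timesteps` is assumed to have been called):

import torch

clean = torch.randn(2, 3, 64, 64)
noise = torch.randn_like(clean)
t = scheduler.timesteps[:2]  # one timestep per batch element
noisy = scheduler.add_noise(clean, noise, t)
assert len(scheduler) == scheduler.config.num_train_timesteps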