Improve docstrings and type hints in scheduling_consistency_models.py (#12931)
docs: improve docstrings in scheduling_consistency_models.py
@@ -83,7 +83,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         s_noise: float = 1.0,
         rho: float = 7.0,
         clip_denoised: bool = True,
-    ):
+    ) -> None:
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = sigma_max
 
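For context, a minimal sketch of constructing the scheduler with the keyword arguments visible in this hunk; the `sigma_max = 80.0` default is an assumption taken from the consistency-models configuration, not shown in the diff:

```python
from diffusers import CMStochasticIterativeScheduler

# __init__ now carries an explicit -> None annotation; behavior is unchanged.
scheduler = CMStochasticIterativeScheduler(s_noise=1.0, rho=7.0, clip_denoised=True)
print(scheduler.init_noise_sigma)  # equals sigma_max (assumed 80.0 by default)
```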
@@ -102,21 +102,29 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
         The index counter for current timestep. It will increase 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index, or `None` if not yet initialized.
         """
         return self._step_index
 
     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index, or `None` if not yet set.
         """
         return self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
 
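A quick sketch of the behavior the new `Optional[int]` annotations document, assuming a freshly constructed scheduler whose internal `_step_index` has not yet been initialized:

```python
from diffusers import CMStochasticIterativeScheduler

scheduler = CMStochasticIterativeScheduler()
print(scheduler.step_index)   # None until the first step() initializes it
scheduler.set_begin_index(0)  # a pipeline would call this before inference
print(scheduler.begin_index)  # 0
```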
@@ -151,7 +159,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = True
         return sample
 
-    def sigma_to_t(self, sigmas: Union[float, np.ndarray]):
+    def sigma_to_t(self, sigmas: Union[float, np.ndarray]) -> np.ndarray:
         """
         Gets scaled timesteps from the Karras sigmas for input to the consistency model.
 
@@ -160,8 +168,8 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
                 A single Karras sigma or an array of Karras sigmas.
 
         Returns:
-            `float` or `np.ndarray`:
-                A scaled input timestep or scaled input timestep array.
+            `np.ndarray`:
+                A scaled input timestep array.
         """
         if not isinstance(sigmas, np.ndarray):
             sigmas = np.array(sigmas, dtype=np.float64)
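The tightened return annotation follows from the coercion shown above: scalar inputs are converted with `np.array(...)`, so the result is always NumPy-typed. A small sketch with illustrative sigma values:

```python
import numpy as np
from diffusers import CMStochasticIterativeScheduler

scheduler = CMStochasticIterativeScheduler()
t = scheduler.sigma_to_t(np.array([80.0, 0.002]))  # an array of Karras sigmas
print(t.shape)  # (2,) — one scaled timestep per input sigma
```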
@@ -173,14 +181,14 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
     def set_timesteps(
         self,
         num_inference_steps: Optional[int] = None,
-        device: Union[str, torch.device] = None,
+        device: Optional[Union[str, torch.device]] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the timesteps used for the diffusion chain (to be run before inference).
 
         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
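Both arguments being `Optional` matches how the method is called in practice: either a step count or an explicit timestep list is supplied. A hedged sketch; the `[22, 0]` values are illustrative and assume custom timesteps must stay below `num_train_timesteps`:

```python
from diffusers import CMStochasticIterativeScheduler

scheduler = CMStochasticIterativeScheduler()

# Either pass a step count and an optional device ...
scheduler.set_timesteps(num_inference_steps=2, device="cpu")
# ... or pass explicit timesteps instead of a step count.
scheduler.set_timesteps(timesteps=[22, 0], device="cpu")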
@@ -244,9 +252,19 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Modified _convert_to_karras implementation that takes in ramp as argument
-    def _convert_to_karras(self, ramp):
-        """Constructs the noise schedule of Karras et al. (2022)."""
+    def _convert_to_karras(self, ramp: np.ndarray) -> np.ndarray:
+        """
+        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            ramp (`np.ndarray`):
+                A ramp array of values between 0 and 1 used to interpolate between sigma_min and sigma_max.
+
+        Returns:
+            `np.ndarray`:
+                The Karras sigma schedule array.
+        """
         sigma_min: float = self.config.sigma_min
         sigma_max: float = self.config.sigma_max
 
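A standalone NumPy sketch of the interpolation this docstring describes, using the formula that appears at the top of the next hunk; the `sigma_min`/`sigma_max`/`rho` values are the assumed config defaults:

```python
import numpy as np

sigma_min, sigma_max, rho = 0.002, 80.0, 7.0   # assumed defaults
ramp = np.linspace(0, 1, 5)  # ramp == 0 maps to sigma_max, ramp == 1 to sigma_min

min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
print(sigmas)  # monotonically decreasing from 80.0 down to 0.002
```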
@@ -256,14 +274,25 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
         return sigmas
 
-    def get_scalings(self, sigma):
+    def get_scalings(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Computes the scaling factors for the consistency model output.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The current sigma value in the noise schedule.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing `c_skip` (scaling for the input sample) and `c_out` (scaling for the model output).
+        """
         sigma_data = self.config.sigma_data
 
         c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
         c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
         return c_skip, c_out
 
-    def get_scalings_for_boundary_condition(self, sigma):
+    def get_scalings_for_boundary_condition(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Gets the scalings used in the consistency model parameterization (from Appendix C of the
         [paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
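A worked example of the `get_scalings` formulas shown in this hunk; the `sigma_data = 0.5` default is an assumption, and the sigma value is illustrative:

```python
import torch

sigma_data = 0.5            # assumed config default
sigma = torch.tensor(1.0)   # illustrative sigma

c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)             # 0.25 / 1.25 = 0.2
c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5  # 0.5 / sqrt(1.25) ≈ 0.447
print(c_skip.item(), c_out.item())
```

Note that `c_skip + c_out`-style weighting keeps the model close to the identity at small sigma, which is why the tuple ordering (`c_skip` first, `c_out` second) documented in the next hunk matters to callers.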
@@ -275,7 +304,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
                 The current sigma in the Karras sigma schedule.
 
         Returns:
-            `tuple`:
+            `Tuple[torch.Tensor, torch.Tensor]`:
                 A two-element tuple where `c_skip` (which weights the current sample) is the first element and `c_out`
                 (which weights the consistency model output) is the second element.
         """
@@ -348,13 +377,13 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         Args:
             model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
-            timestep (`float`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
-            return_dict (`bool`, *optional*, defaults to `True`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a
                 [`~schedulers.scheduling_consistency_models.CMStochasticIterativeSchedulerOutput`] or `tuple`.
 
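The widened `timestep` hint reflects that pipelines commonly pass elements of `scheduler.timesteps`, which are tensors. A hedged sketch of a single step; shapes, the random model output, and the seed are all illustrative:

```python
import torch
from diffusers import CMStochasticIterativeScheduler

scheduler = CMStochasticIterativeScheduler()
scheduler.set_timesteps(num_inference_steps=2)

sample = torch.randn(1, 3, 32, 32) * scheduler.init_noise_sigma
timestep = scheduler.timesteps[0]           # a tensor, hence the widened hint
scaled = scheduler.scale_model_input(sample, timestep)
model_output = torch.randn(1, 3, 32, 32)    # stand-in for the consistency model

out = scheduler.step(model_output, timestep, sample,
                     generator=torch.Generator().manual_seed(0))
sample = out.prev_sample                    # or out[0] when return_dict=False
```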
@@ -406,7 +435,10 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         # Noise is not used for onestep sampling.
         if len(self.timesteps) > 1:
             noise = randn_tensor(
-                model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
+                model_output.shape,
+                dtype=model_output.dtype,
+                device=model_output.device,
+                generator=generator,
             )
         else:
             noise = torch.zeros_like(model_output)
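This hunk only reformats the `randn_tensor` call, one keyword per line. For reference, a minimal sketch of the same call seeded for reproducibility; the shape is illustrative:

```python
import torch
from diffusers.utils.torch_utils import randn_tensor

generator = torch.Generator("cpu").manual_seed(0)
noise = randn_tensor(
    (1, 3, 32, 32),
    dtype=torch.float32,
    device=torch.device("cpu"),
    generator=generator,
)
```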
@@ -475,5 +507,12 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Returns the number of training timesteps.
+
+        Returns:
+            `int`:
+                The number of training timesteps configured for the scheduler.
+        """
         return self.config.num_train_timesteps
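The annotated `__len__` simply surfaces the config value, so `len(scheduler)` works as a shorthand; the default of 40 is an assumption from the scheduler's config:

```python
from diffusers import CMStochasticIterativeScheduler

scheduler = CMStochasticIterativeScheduler()
print(len(scheduler))  # num_train_timesteps; 40 with the assumed default config
```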