mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Refac training utils.py (#9815)
* Refac training utils.py * quality --------- Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -43,6 +43,9 @@ def set_seed(seed: int):
|
||||
|
||||
Args:
|
||||
seed (`int`): The seed to set.
|
||||
|
||||
Returns:
|
||||
`None`
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
@@ -58,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps):
|
||||
"""
|
||||
Computes SNR as per
|
||||
https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
for the given timesteps using the provided noise scheduler.
|
||||
|
||||
Args:
|
||||
noise_scheduler (`NoiseScheduler`):
|
||||
An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
|
||||
the SNR values.
|
||||
timesteps (`torch.Tensor`):
|
||||
A tensor of timesteps for which the SNR is computed.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: A tensor containing the computed SNR values for each timestep.
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
|
||||
Reference in New Issue
Block a user