From a3cc641f78bd0c4a749e8ad03141d7fdb76eec1c Mon Sep 17 00:00:00 2001 From: RogerSinghChugh <35698080+RogerSinghChugh@users.noreply.github.com> Date: Mon, 4 Nov 2024 23:10:44 +0530 Subject: [PATCH] Refac training utils.py (#9815) * Refac training utils.py * quality --------- Co-authored-by: sayakpaul --- src/diffusers/training_utils.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index d2bf3fe071..2474ed5c21 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -43,6 +43,9 @@ def set_seed(seed: int): Args: seed (`int`): The seed to set. + + Returns: + `None` """ random.seed(seed) np.random.seed(seed) @@ -58,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps): """ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + for the given timesteps using the provided noise scheduler. + + Args: + noise_scheduler (`NoiseScheduler`): + An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute + the SNR values. + timesteps (`torch.Tensor`): + A tensor of timesteps for which the SNR is computed. + + Returns: + `torch.Tensor`: A tensor containing the computed SNR values for each timestep. """ alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5