From f45c675d2cf1e24d8bee024658f112d4c86aa784 Mon Sep 17 00:00:00 2001 From: aengusng8 Date: Tue, 3 Jan 2023 20:08:41 +0700 Subject: [PATCH] [addresses issue #1642] add add_noise to scheduling-sde-ve (#1827) * add add_noise to scheduling-sde-ve * run Black formater --- src/diffusers/schedulers/scheduling_sde_ve.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 89d3d4a585..3d9e18ca65 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -262,5 +262,18 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): return SchedulerOutput(prev_sample=prev_sample) + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + timesteps = timesteps.to(original_samples.device) + sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps] + noise = torch.randn_like(original_samples) * sigmas[:, None, None, None] + noisy_samples = noise + original_samples + return noisy_samples + def __len__(self): return self.config.num_train_timesteps