1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[addresses issue #1642] add add_noise to scheduling-sde-ve (#1827)

* add add_noise to scheduling-sde-ve

* run Black formater
This commit is contained in:
aengusng8
2023-01-03 20:08:41 +07:00
committed by GitHub
parent 1bf4f0da7e
commit f45c675d2c

View File

@@ -262,5 +262,18 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
original_samples: torch.FloatTensor,
noise: torch.FloatTensor,
timesteps: torch.FloatTensor,
) -> torch.FloatTensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
timesteps = timesteps.to(original_samples.device)
sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
noise = torch.randn_like(original_samples) * sigmas[:, None, None, None]
noisy_samples = noise + original_samples
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps