mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
* add add_noise to scheduling-sde-ve * run Black formater
This commit is contained in:
@@ -262,5 +262,18 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return SchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
sigmas = self.discrete_sigmas.to(original_samples.device)[timesteps]
|
||||
noise = torch.randn_like(original_samples) * sigmas[:, None, None, None]
|
||||
noisy_samples = noise + original_samples
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
Reference in New Issue
Block a user