diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py index fbdc0aba29..79144e9c02 100644 --- a/examples/train_unconditional.py +++ b/examples/train_unconditional.py @@ -39,7 +39,7 @@ def main(args): resamp_with_conv=True, resolution=args.resolution, ) - noise_scheduler = DDPMScheduler(timesteps=1000) + noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt") optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) augmentations = Compose( @@ -93,15 +93,13 @@ def main(args): pbar.set_description(f"Epoch {epoch}") for step, batch in enumerate(train_dataloader): clean_images = batch["input"] - noisy_images = torch.empty_like(clean_images) - noise_samples = torch.empty_like(clean_images) + noise_samples = torch.randn(clean_images.shape).to(clean_images.device) bsz = clean_images.shape[0] - timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long() - for idx in range(bsz): - noise = torch.randn(clean_images.shape[1:]).to(clean_images.device) - noise_samples[idx] = noise - noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx]) + + # add noise onto the clean images according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps) if step % args.gradient_accumulation_steps != 0: with accelerator.no_sync(model): @@ -146,7 +144,7 @@ def main(args): # save image test_dir = os.path.join(args.output_dir, "test_samples") os.makedirs(test_dir, exist_ok=True) - image_pil.save(f"{test_dir}/{epoch}.png") + image_pil.save(f"{test_dir}/{epoch:04d}.png") # save the model if args.push_to_hub: diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index eb85796f27..206b1477f2 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -17,6 +17,7 @@ import math import numpy as np +import torch from ..configuration_utils import ConfigMixin from .scheduling_utils import SchedulerMixin @@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): return pred_prev_sample - def forward_step(self, original_sample, noise, t): - sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5 - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5 - noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise - return noisy_sample + def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor): + if timesteps.dim() != 1: + raise ValueError("`timesteps` must be a 1D tensor") + + device = original_samples.device + batch_size = original_samples.shape[0] + timesteps = timesteps.reshape(batch_size, 1, 1, 1) + + sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 + noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise + return noisy_samples def __len__(self): return self.config.timesteps