1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

batched forward diffusion step

This commit is contained in:
anton-l
2022-06-22 13:38:14 +02:00
parent 62c2c547db
commit 848c86ca0a
2 changed files with 20 additions and 14 deletions

View File

@@ -39,7 +39,7 @@ def main(args):
resamp_with_conv=True,
resolution=args.resolution,
)
noise_scheduler = DDPMScheduler(timesteps=1000)
noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
augmentations = Compose(
@@ -93,15 +93,13 @@ def main(args):
pbar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"]
noisy_images = torch.empty_like(clean_images)
noise_samples = torch.empty_like(clean_images)
noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
bsz = clean_images.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
for idx in range(bsz):
noise = torch.randn(clean_images.shape[1:]).to(clean_images.device)
noise_samples[idx] = noise
noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
# add noise onto the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model):
@@ -146,7 +144,7 @@ def main(args):
# save image
test_dir = os.path.join(args.output_dir, "test_samples")
os.makedirs(test_dir, exist_ok=True)
image_pil.save(f"{test_dir}/{epoch}.png")
image_pil.save(f"{test_dir}/{epoch:04d}.png")
# save the model
if args.push_to_hub:

View File

@@ -17,6 +17,7 @@
import math
import numpy as np
import torch
from ..configuration_utils import ConfigMixin
from .scheduling_utils import SchedulerMixin
@@ -142,11 +143,18 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_sample
def forward_step(self, original_sample, noise, t):
sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5
noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
return noisy_sample
def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
if timesteps.dim() != 1:
raise ValueError("`timesteps` must be a 1D tensor")
device = original_samples.device
batch_size = original_samples.shape[0]
timesteps = timesteps.reshape(batch_size, 1, 1, 1)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
return noisy_samples
def __len__(self):
return self.config.timesteps