From abbbc27e8821ffad6a484975f33126ff44d7424e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 10 Jun 2022 14:50:57 +0200 Subject: [PATCH] Update README.md --- README.md | 53 +++++++++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index acae868d17..caa9e67f77 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu" # 1. Load models noise_scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church") -model = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) +unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) # 2. Sample gaussian noise image = noise_scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) @@ -53,21 +53,21 @@ image = noise_scheduler.sample_noise((1, model.in_channels, model.resolution, mo # 3. Denoise num_prediction_steps = len(noise_scheduler) for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): - # predict noise residual - with torch.no_grad(): - residual = unet(image, t) + # predict noise residual + with torch.no_grad(): + residual = unet(image, t) - # predict previous mean of image x_t-1 - pred_prev_image = noise_scheduler.compute_prev_image_step(residual, image, t) + # predict previous mean of image x_t-1 + pred_prev_image = noise_scheduler.compute_prev_image_step(residual, image, t) - # optionally sample variance - variance = 0 - if t > 0: - noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) - variance = noise_scheduler.get_variance(t).sqrt() * noise + # optionally sample variance + variance = 0 + if t > 0: + noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) + variance = noise_scheduler.get_variance(t).sqrt() * noise - # set current image to prev_image: x_t -> x_t-1 - image = pred_prev_image + variance + # set current image to prev_image: x_t -> x_t-1 + image = pred_prev_image + variance # 5. process image to PIL image_processed = image.cpu().permute(0, 2, 3, 1) @@ -93,7 +93,7 @@ torch_device = "cuda" if torch.cuda.is_available() else "cpu" # 1. Load models noise_scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq") -model = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device) +unet = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device) # 2. Sample gaussian noise image = noise_scheduler.sample_noise((1, model.in_channels, model.resolution, model.resolution), device=torch_device, generator=generator) @@ -103,21 +103,22 @@ num_inference_steps = 50 eta = 0.0 # <- deterministic sampling for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps): - # 1. predict noise residual - with torch.no_grad(): - residual = unet(image, inference_step_times[t]) + # 1. predict noise residual + orig_t = noise_scheduler.get_orig_t(t, num_inference_steps) + with torch.no_grad(): + residual = unet(image, orig_t) - # 2. predict previous mean of image x_t-1 - pred_prev_image = noise_scheduler.compute_prev_image_step(residual, image, t, num_inference_steps, eta) + # 2. predict previous mean of image x_t-1 + pred_prev_image = noise_scheduler.compute_prev_image_step(residual, image, t, num_inference_steps, eta) - # 3. optionally sample variance - variance = 0 - if eta > 0: - noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) - variance = noise_scheduler.get_variance(t).sqrt() * eta * noise + # 3. optionally sample variance + variance = 0 + if eta > 0: + noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) + variance = noise_scheduler.get_variance(t).sqrt() * eta * noise - # 4. set current image to prev_image: x_t -> x_t-1 - image = pred_prev_image + variance + # 4. set current image to prev_image: x_t -> x_t-1 + image = pred_prev_image + variance # 5. process image to PIL image_processed = image.cpu().permute(0, 2, 3, 1)