From 20d91782374ef076a560a2a2bc61f4c8bf6dd629 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sun, 12 Jun 2022 22:14:03 +0000 Subject: [PATCH] correct readme --- README.md | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index e6b071577d..7d7c8c755a 100644 --- a/README.md +++ b/README.md @@ -48,9 +48,13 @@ noise_scheduler = GaussianDDPMScheduler.from_config("fusing/ddpm-lsun-church") unet = UNetModel.from_pretrained("fusing/ddpm-lsun-church").to(torch_device) # 2. Sample gaussian noise -image = noise_scheduler.sample_noise((1, unet.in_channels, unet.resolution, unet.resolution), device=torch_device, generator=generator) +image = torch.randn( + (1, unet.in_channels, unet.resolution, unet.resolution) + generator=generator, +) +image = image.to(torch_device) -# 3. Denoise +# 3. Denoise num_prediction_steps = len(noise_scheduler) for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): # predict noise residual @@ -63,7 +67,7 @@ for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_s # optionally sample variance variance = 0 if t > 0: - noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) + noise = torch.randn(image.shape, generator=generator).to(image.device) variance = noise_scheduler.get_variance(t).sqrt() * noise # set current image to prev_image: x_t -> x_t-1 @@ -96,7 +100,11 @@ noise_scheduler = DDIMScheduler.from_config("fusing/ddpm-celeba-hq") unet = UNetModel.from_pretrained("fusing/ddpm-celeba-hq").to(torch_device) # 2. Sample gaussian noise -image = noise_scheduler.sample_noise((1, unet.in_channels, unet.resolution, unet.resolution), device=torch_device, generator=generator) +image = torch.randn( + (1, unet.in_channels, unet.resolution, unet.resolution) + generator=generator, +) +image = image.to(torch_device) # 3. Denoise num_inference_steps = 50 @@ -114,7 +122,7 @@ for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_ste # 3. optionally sample variance variance = 0 if eta > 0: - noise = noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator) + noise = torch.randn(image.shape, generator=generator).to(image.device) variance = noise_scheduler.get_variance(t).sqrt() * eta * noise # 4. set current image to prev_image: x_t -> x_t-1