diff --git a/README.md b/README.md index eba9e09410..6d3ea31100 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ num_prediction_steps = len(noise_scheduler) for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps): # predict noise residual with torch.no_grad(): - residual = unet(image, t) + residual = unet(image, t) # predict previous mean of image x_t-1 pred_prev_image = noise_scheduler.step(residual, image, t) @@ -107,7 +107,7 @@ for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_s variance = 0 if t > 0: noise = torch.randn(image.shape, generator=generator).to(image.device) - variance = noise_scheduler.get_variance(t).sqrt() * noise + variance = noise_scheduler.get_variance(t).sqrt() * noise # set current image to prev_image: x_t -> x_t-1 image = pred_prev_image + variance