diff --git a/README.md b/README.md index 8f6db61267..7ca8d7b23b 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_ste # 1. predict noise residual orig_t = len(noise_scheduler) // num_inference_steps * t - with torch.inference_mode(): + with torch.no_grad(): residual = unet(image, orig_t) # 2. predict previous mean of image x_t-1