1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Update scheduling_repaint.py (#1582)

* Update scheduling_repaint.py

* update the expected image

Co-authored-by: anton- <anton@huggingface.co>
This commit is contained in:
Randolph-zeng
2022-12-08 00:41:07 +08:00
committed by GitHub
parent ced7c9601a
commit ca68ab3eef
2 changed files with 4 additions and 5 deletions

View File

@@ -287,7 +287,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin):
prev_unknown_part = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction + variance
# 8. Algorithm 1 Line 5 https://arxiv.org/pdf/2201.09865.pdf
prev_known_part = (alpha_prod_t**0.5) * original_image + ((1 - alpha_prod_t) ** 0.5) * noise
prev_known_part = (alpha_prod_t_prev**0.5) * original_image + ((1 - alpha_prod_t_prev) ** 0.5) * noise
# 9. Algorithm 1 Line 8 https://arxiv.org/pdf/2201.09865.pdf
pred_prev_sample = mask * prev_known_part + (1.0 - mask) * prev_unknown_part

View File

@@ -19,7 +19,7 @@ import numpy as np
import torch
from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel
from diffusers.utils.testing_utils import load_image, require_torch_gpu, slow, torch_device
from diffusers.utils.testing_utils import load_image, load_numpy, require_torch_gpu, slow, torch_device
torch.backends.cuda.matmul.allow_tf32 = False
@@ -36,11 +36,10 @@ class RepaintPipelineIntegrationTests(unittest.TestCase):
mask_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"
)
expected_image = load_image(
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/"
"repaint/celeba_hq_256_result.png"
"repaint/celeba_hq_256_result.npy"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
model_id = "google/ddpm-ema-celebahq-256"
unet = UNet2DModel.from_pretrained(model_id)