mirror of https://github.com/huggingface/diffusers.git

fix reproducible initial noise

commit f300d05cb9
Author: anton-l
Date:   2022-11-10 00:00:07 +01:00
Parent: 86d4c5a254

3 changed files with 29 additions and 23 deletions
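Context: seeded torch.randn is not reproducible on Apple's mps backend, so the pipelines below switch to drawing the initial noise on the CPU and moving it to the target device afterwards. A minimal sketch of the pattern the diffs adopt (the helper name sample_initial_noise is illustrative, not part of the commit):

    import torch

    def sample_initial_noise(shape, generator, device):
        # Draw on the CPU when targeting mps, where seeded randn is not
        # reproducible, then move the tensor to the pipeline's device.
        if device.type == "mps":
            noise = torch.randn(shape, generator=generator)  # CPU draw
            return noise.to(device)
        return torch.randn(shape, generator=generator, device=device)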


@@ -89,11 +89,13 @@ class DDIMPipeline(DiffusionPipeline):
             generator = None
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(self.device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
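Why this restores determinism: a seeded draw on the CPU generator always produces the same tensor, and the later .to(device) copy leaves the values untouched. A quick self-contained check (the shape is illustrative):

    import torch

    shape = (1, 3, 32, 32)  # illustrative (batch, channels, height, width)
    a = torch.randn(shape, generator=torch.manual_seed(0))
    b = torch.randn(shape, generator=torch.manual_seed(0))
    assert torch.equal(a, b)  # same seed, identical noise, on any host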


@@ -94,11 +94,13 @@ class DDPMPipeline(DiffusionPipeline):
             generator = None
 
         # Sample gaussian noise to begin loop
-        image = torch.randn(
-            (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
-            generator=generator,
-            device=self.device,
-        )
+        image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
+        if self.device.type == "mps":
+            # randn does not work reproducibly on mps
+            image = torch.randn(image_shape, generator=generator)
+            image = image.to(self.device)
+        else:
+            image = torch.randn(image_shape, generator=generator, device=self.device)
 
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
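The same pattern lands in DDPMPipeline, and the test updates below follow from it: since the initial noise is now drawn on the CPU (and mps does not support the torch.Generator() API, per the comment in the diff), the tests seed with torch.manual_seed(0) instead of constructing a generator on the test device. A small sketch of what that call returns:

    import torch

    gen = torch.manual_seed(0)    # seeds and returns torch's default Generator
    print(type(gen), gen.device)  # <class 'torch._C.Generator'> cpu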


@@ -20,7 +20,7 @@ import torch
 
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import require_torch, slow, torch_device
 
 from ...test_pipelines_common import PipelineTesterMixin
@@ -44,18 +44,21 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model
 
     def test_inference(self):
-        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()
 
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(device)
+        ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=device).manual_seed(0)
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
+        generator = torch.manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        generator = torch.Generator(device=device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
@@ -65,8 +68,9 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
 
     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -80,6 +84,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         # Warmup pass when using mps (see #372)
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)
+
         if torch_device == "mps":
             # device type MPS is not supported for torch.Generator() api.
             generator = torch.manual_seed(0)
@@ -87,9 +92,6 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             generator = torch.Generator(device=torch_device).manual_seed(0)
 
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
         generator = generator.manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
@@ -102,7 +104,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
 
 
 @slow
-@require_torch_gpu
+@require_torch
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
@@ -114,11 +116,11 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
+        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
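A minimal sketch of reproducing the integration test's seeded run after this change (the model id is taken from the test above; weights download on first use):

    import torch
    from diffusers import DDPMPipeline

    ddpm = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")
    generator = torch.manual_seed(0)  # CPU generator works on every backend
    image = ddpm(generator=generator, output_type="numpy").images
    print(image.shape)  # (1, 32, 32, 3)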