mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix reproducible initial noise
This commit is contained in:
@@ -89,11 +89,13 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
generator = None
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
)
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = torch.randn(image_shape, generator=generator)
|
||||
image = image.to(device)
|
||||
else:
|
||||
image = torch.randn(image_shape, generator=generator, device=self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -94,11 +94,13 @@ class DDPMPipeline(DiffusionPipeline):
|
||||
generator = None
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
(batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
|
||||
generator=generator,
|
||||
device=self.device,
|
||||
)
|
||||
image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
|
||||
if self.device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
image = torch.randn(image_shape, generator=generator)
|
||||
image = image.to(device)
|
||||
else:
|
||||
image = torch.randn(image_shape, generator=generator, device=self.device)
|
||||
|
||||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
|
||||
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||
from diffusers.utils import deprecate
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch, slow, torch_device
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
@@ -44,18 +44,21 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
return model
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
unet = self.dummy_uncond_unet
|
||||
scheduler = DDPMScheduler()
|
||||
|
||||
ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
|
||||
ddpm.to(device)
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
_ = ddpm(num_inference_steps=1)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
generator = torch.manual_seed(0)
|
||||
image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
@@ -65,8 +68,9 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
expected_slice = np.array(
|
||||
[5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
|
||||
)
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
tolerance = 1e-2 if torch_device != "mps" else 3e-2
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
|
||||
|
||||
def test_inference_predict_epsilon(self):
|
||||
deprecate("remove this test", "0.10.0", "remove")
|
||||
@@ -80,6 +84,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
_ = ddpm(num_inference_steps=1)
|
||||
|
||||
if torch_device == "mps":
|
||||
# device type MPS is not supported for torch.Generator() api.
|
||||
generator = torch.manual_seed(0)
|
||||
@@ -87,9 +92,6 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
|
||||
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps":
|
||||
_ = ddpm(num_inference_steps=1)
|
||||
generator = generator.manual_seed(0)
|
||||
image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]
|
||||
|
||||
@@ -102,7 +104,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_torch
|
||||
class DDPMPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_inference_cifar10(self):
|
||||
model_id = "google/ddpm-cifar10-32"
|
||||
@@ -114,11 +116,11 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
|
||||
ddpm.to(torch_device)
|
||||
ddpm.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
generator = torch.manual_seed(0)
|
||||
image = ddpm(generator=generator, output_type="numpy").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
|
||||
expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
Reference in New Issue
Block a user