
Match the generator device to the pipeline for DDPM and DDIM (#1222)

* Match the generator device to the pipeline for DDPM and DDIM

* style

* fix

* update values

* fix fast tests

* trigger slow tests

* deprecate

* last value fixes

* mps fixes
Anton Lozhkov, 2022-11-09 23:00:23 +01:00, committed by GitHub
parent 3d98dc763a, commit 7d0c272939
8 changed files with 115 additions and 110 deletions
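
In user terms, the fix means a seeded run should build its generator on the pipeline's own device rather than relying on torch.manual_seed. A minimal sketch of the before/after usage (the device is illustrative; the checkpoint id is one used in the tests below):

import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").to("cuda")

# before: a CPU generator, which now triggers a deprecation warning
# generator = torch.manual_seed(0)

# after: match the generator device to the pipeline device
generator = torch.Generator(device=pipe.device).manual_seed(0)
images = pipe(generator=generator, output_type="numpy").images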

View File

@@ -11,10 +11,12 @@ import torch.nn.functional as F
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
-from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, __version__
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
+from diffusers.utils import deprecate
 from huggingface_hub import HfFolder, Repository, whoami
+from packaging import version
 from torchvision.transforms import (
     CenterCrop,
     Compose,

@@ -28,6 +30,7 @@ from tqdm.auto import tqdm

 logger = get_logger(__name__)

+diffusers_version = version.parse(version.parse(__version__).base_version)


 def _extract_into_tensor(arr, timesteps, broadcast_shape):

@@ -406,7 +409,11 @@ def main(args):
                     scheduler=noise_scheduler,
                 )

-                generator = torch.manual_seed(0)
+                deprecate("todo: remove this check", "0.10.0", "when the most used version is >= 0.8.0")
+                if diffusers_version < version.parse("0.8.0"):
+                    generator = torch.manual_seed(0)
+                else:
+                    generator = torch.Generator(device=pipeline.device).manual_seed(0)
                 # run pipeline in inference (sample random noise and denoise)
                 images = pipeline(
                     generator=generator,
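
Downstream training scripts that need to run across diffusers releases can reuse the same gate; a minimal sketch, where make_generator is a hypothetical helper name, not part of the library:

import torch
from packaging import version
from diffusers import __version__

diffusers_version = version.parse(version.parse(__version__).base_version)

def make_generator(device, seed=0):
    # Releases before 0.8.0 expected the global CPU RNG; later ones
    # reproduce best with a generator on the pipeline's device.
    if diffusers_version < version.parse("0.8.0"):
        return torch.manual_seed(seed)
    return torch.Generator(device=device).manual_seed(seed)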

View File

@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import inspect
 from typing import Optional, Tuple, Union

 import torch

 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ...utils import deprecate


 class DDIMPipeline(DiffusionPipeline):

@@ -75,24 +75,29 @@ class DDIMPipeline(DiffusionPipeline):
             generated images.
         """

+        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+            message = (
+                f"The `generator` device is `{generator.device}` and does not match the pipeline "
+                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
+            )
+            deprecate(
+                "generator.device == 'cpu'",
+                "0.11.0",
+                message,
+            )
+            generator = None
+
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)

-        # Ignore use_clipped_model_output if the scheduler doesn't accept this argument
-        accepts_use_clipped_model_output = "use_clipped_model_output" in set(
-            inspect.signature(self.scheduler.step).parameters.keys()
-        )
-        extra_kwargs = {}
-        if accepts_use_clipped_model_output:
-            extra_kwargs["use_clipped_model_output"] = use_clipped_model_output
-
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t).sample

@@ -100,7 +105,9 @@ class DDIMPipeline(DiffusionPipeline):
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # eta corresponds to η in paper and should be between [0, 1]
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta, **extra_kwargs).prev_sample
+            image = self.scheduler.step(
+                model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
+            ).prev_sample

         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
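
With the inspect-based plumbing removed, eta, use_clipped_model_output, and generator now flow directly into DDIMScheduler.step. A hedged usage sketch of that call path (checkpoint id from the tests below; eta > 0 adds noise inside each step, which is why the scheduler now also consumes the generator):

import torch
from diffusers import DDIMPipeline

ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32").to("cuda")
generator = torch.Generator(device="cuda").manual_seed(0)
images = ddim(
    generator=generator,
    eta=1.0,
    use_clipped_model_output=True,
    output_type="numpy",
).images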

View File

@@ -80,12 +80,25 @@ class DDPMPipeline(DiffusionPipeline):
             new_config["predict_epsilon"] = predict_epsilon
             self.scheduler._internal_dict = FrozenDict(new_config)

+        if generator is not None and generator.device.type != self.device.type and self.device.type != "mps":
+            message = (
+                f"The `generator` device is `{generator.device}` and does not match the pipeline "
+                f"device `{self.device}`, so the `generator` will be set to `None`. "
+                f'Please use `torch.Generator(device="{self.device}")` instead.'
+            )
+            deprecate(
+                "generator.device == 'cpu'",
+                "0.11.0",
+                message,
+            )
+            generator = None
+
         # Sample gaussian noise to begin loop
         image = torch.randn(
             (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size),
             generator=generator,
+            device=self.device,
         )
-        image = image.to(self.device)

         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
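
Note the mps carve-out in the guard: on mps a CPU generator is expected, so only a genuine cross-device mismatch is deprecated. An illustration of the condition with example devices (not library code):

import torch

pipeline_device = torch.device("cuda")   # example pipeline device
generator = torch.manual_seed(0)         # a CPU generator

mismatch = (
    generator is not None
    and generator.device.type != pipeline_device.type
    and pipeline_device.type != "mps"
)
# mismatch is True here: the pipeline warns and falls back to generator=None,
# i.e. unseeded sampling on the pipeline device.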

View File

@@ -292,10 +292,16 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # 6. Add noise
         variance = 0
         if t > 0:
-            noise = torch.randn(
-                model_output.size(), dtype=model_output.dtype, layout=model_output.layout, generator=generator
-            ).to(model_output.device)
-            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * noise
+            device = model_output.device
+            if device.type == "mps":
+                # randn does not work reproducibly on mps
+                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
+                variance_noise = variance_noise.to(device)
+            else:
+                variance_noise = torch.randn(
+                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
+                )
+            variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise

         pred_prev_sample = pred_prev_sample + variance
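
The same pattern generalizes beyond this scheduler: on mps, draw the noise with a CPU generator and move it; elsewhere, draw directly on the target device. A standalone sketch (helper name is ours):

import torch

def sample_noise_like(ref, generator=None):
    # Mirror of the fix above: torch.randn is not reproducible on mps,
    # so sample on CPU and transfer.
    if ref.device.type == "mps":
        noise = torch.randn(ref.shape, dtype=ref.dtype, generator=generator)
        return noise.to(ref.device)
    # On other devices, a generator passed here must itself live on ref.device.
    return torch.randn(ref.shape, dtype=ref.dtype, generator=generator, device=ref.device)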

View File

@@ -19,7 +19,7 @@ import numpy as np
 import torch

 from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin

@@ -43,21 +43,18 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()

         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

         image_slice = image[0, -3:, -3:, -1]

@@ -67,13 +64,12 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2


 @slow
-@require_torch
+@require_torch_gpu
 class DDIMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_ema_bedroom(self):
         model_id = "google/ddpm-ema-bedroom-256"

@@ -85,13 +81,13 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.00605, 0.0201, 0.0344, 0.00235, 0.00185, 0.00025, 0.00215, 0.0, 0.00685])
+        expected_slice = np.array([0.1546, 0.1561, 0.1595, 0.1564, 0.1569, 0.1585, 0.1554, 0.1550, 0.1575])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_cifar10(self):

@@ -104,11 +100,11 @@ class DDIMPipelineIntegrationTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddim(generator=generator, eta=0.0, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.17235, 0.16175, 0.16005, 0.16255, 0.1497, 0.1513, 0.15045, 0.1442, 0.1453])
+        expected_slice = np.array([0.2060, 0.2042, 0.2022, 0.2193, 0.2146, 0.2110, 0.2471, 0.2446, 0.2388])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
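
All of these tests share the expected-slice idiom: compare the bottom-right 3x3 patch of the last channel against pinned reference values. A small helper capturing it (name and signature are ours):

import numpy as np

def assert_matches_slice(images, expected_slice, tol=1e-2):
    # images: (batch, height, width, channels) numpy output of a pipeline
    image_slice = images[0, -3:, -3:, -1]
    assert np.abs(image_slice.flatten() - expected_slice).max() < tol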

View File

@@ -20,7 +20,7 @@ import torch

 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.utils import deprecate
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device

 from ...test_pipelines_common import PipelineTesterMixin

@@ -44,21 +44,18 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         return model

     def test_inference(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDPMScheduler()

         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
-        ddpm.to(torch_device)
+        ddpm.to(device)
         ddpm.set_progress_bar_config(disable=None)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]

         image_slice = image[0, -3:, -3:, -1]

@@ -68,9 +65,8 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         expected_slice = np.array(
             [5.589e-01, 7.089e-01, 2.632e-01, 6.841e-01, 1.000e-04, 9.999e-01, 1.973e-01, 1.000e-04, 8.010e-02]
         )
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2

     def test_inference_predict_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")

@@ -85,10 +81,10 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         if torch_device == "mps":
             _ = ddpm(num_inference_steps=1)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image_eps = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", predict_epsilon=False)[0]

         image_slice = image[0, -3:, -3:, -1]

@@ -100,7 +96,7 @@ class DDPMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

 @slow
-@require_torch
+@require_torch_gpu
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"

@@ -112,11 +108,11 @@ class DDPMPipelineIntegrationTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 32, 32, 3)
-        expected_slice = np.array([0.41995, 0.35885, 0.19385, 0.38475, 0.3382, 0.2647, 0.41545, 0.3582, 0.33845])
+        expected_slice = np.array([0.4454, 0.2025, 0.0315, 0.3023, 0.2575, 0.1031, 0.0953, 0.1604, 0.2020])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

View File

@@ -68,30 +68,25 @@ class LDMSuperResolutionPipelineFastTests(PipelineTesterMixin, unittest.TestCase
         return model

     def test_inference_superresolution(self):
+        device = "cpu"
         unet = self.dummy_uncond_unet
         scheduler = DDIMScheduler()
         vqvae = self.dummy_vq_model

         ldm = LDMSuperResolutionPipeline(unet=unet, vqvae=vqvae, scheduler=scheduler)
-        ldm.to(torch_device)
+        ldm.to(device)
         ldm.set_progress_bar_config(disable=None)

-        init_image = self.dummy_image.to(torch_device)
+        init_image = self.dummy_image.to(device)

-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(init_image, generator=generator, num_inference_steps=1, output_type="numpy").images
-
-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=device).manual_seed(0)
         image = ldm(init_image, generator=generator, num_inference_steps=2, output_type="numpy").images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.8634, 0.8186, 0.6416, 0.6846, 0.4427, 0.5676, 0.4679, 0.6247, 0.5176])
-        tolerance = 1e-2 if torch_device != "mps" else 3e-2
-        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        expected_slice = np.array([0.8678, 0.8245, 0.6381, 0.6830, 0.4385, 0.5599, 0.4641, 0.6201, 0.5150])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


 @slow

View File

@@ -42,7 +42,6 @@ from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
 from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
-from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -93,11 +92,17 @@ class DownloadTests(unittest.TestCase):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
         )
-        generator = torch.Generator(device=torch_device).manual_seed(0)
         pipe = pipe.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         pipe_2 = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
-        generator_2 = torch.Generator(device=torch_device).manual_seed(0)
         pipe_2 = pipe_2.to(torch_device)
+        generator_2 = generator.manual_seed(0)
         out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
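
The generator_2 = generator.manual_seed(0) reuse works because torch.Generator.manual_seed reseeds in place and returns the same generator object, so the device carries over without rebuilding. A quick sketch:

import torch

generator = torch.Generator(device="cpu").manual_seed(0)
a = torch.randn(4, generator=generator)
# manual_seed() returns the same (device-preserving) generator, reset to the seed
b = torch.randn(4, generator=generator.manual_seed(0))
assert torch.equal(a, b)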
@@ -107,13 +112,19 @@ class DownloadTests(unittest.TestCase):
         pipe = StableDiffusionPipeline.from_pretrained(
             "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None
         )
-        generator = torch.Generator(device=torch_device).manual_seed(0)
         pipe = pipe.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname, safety_checker=None)
-            generator_2 = torch.Generator(device=torch_device).manual_seed(0)
             pipe_2 = pipe_2.to(torch_device)
+            generator_2 = generator.manual_seed(0)
             out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3

@@ -121,13 +132,19 @@ class DownloadTests(unittest.TestCase):
     def test_load_no_safety_checker_default_locally(self):
         prompt = "hello"
         pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
-        generator = torch.Generator(device=torch_device).manual_seed(0)
         pipe = pipe.to(torch_device)
+        if torch_device == "mps":
+            # device type MPS is not supported for torch.Generator() api.
+            generator = torch.manual_seed(0)
+        else:
+            generator = torch.Generator(device=torch_device).manual_seed(0)
         out = pipe(prompt, num_inference_steps=2, generator=generator, output_type="numpy").images

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)
             pipe_2 = StableDiffusionPipeline.from_pretrained(tmpdirname)
-            generator_2 = torch.Generator(device=torch_device).manual_seed(0)
             pipe_2 = pipe_2.to(torch_device)
+            generator_2 = generator.manual_seed(0)
             out_2 = pipe_2(prompt, num_inference_steps=2, generator=generator_2, output_type="numpy").images

         assert np.max(np.abs(out - out_2)) < 1e-3
@@ -431,7 +448,7 @@ class PipelineSlowTests(unittest.TestCase):
         new_ddpm = DDPMPipeline.from_pretrained(tmpdirname)
         new_ddpm.to(torch_device)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)

@@ -452,7 +469,7 @@ class PipelineSlowTests(unittest.TestCase):
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)

@@ -475,7 +492,7 @@ class PipelineSlowTests(unittest.TestCase):
         ddpm_from_hub = ddpm_from_hub.to(torch_device)
         ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images

         generator = generator.manual_seed(0)

@@ -491,7 +508,7 @@ class PipelineSlowTests(unittest.TestCase):
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(0)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         images = pipe(generator=generator, output_type="numpy").images

         assert images.shape == (1, 32, 32, 3)
         assert isinstance(images, np.ndarray)

@@ -506,40 +523,8 @@ class PipelineSlowTests(unittest.TestCase):
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)

-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality(self, seed):
-        model_id = "google/ddpm-cifar10-32"
-
-        unet = UNet2DModel.from_pretrained(model_id)
-        ddpm_scheduler = DDPMScheduler()
-        ddim_scheduler = DDIMScheduler()
-
-        ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
-        ddpm.to(torch_device)
-        ddpm.set_progress_bar_config(disable=None)
-
-        ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
-        ddim.to(torch_device)
-        ddim.set_progress_bar_config(disable=None)
-
-        generator = torch.manual_seed(seed)
-        ddpm_image = ddpm(generator=generator, output_type="numpy").images
-
-        generator = torch.manual_seed(seed)
-        ddim_image = ddim(
-            generator=generator,
-            num_inference_steps=1000,
-            eta=1.0,
-            output_type="numpy",
-            use_clipped_model_output=True,  # Need this to make DDIM match DDPM
-        ).images
-
-        # the values aren't exactly equal, but the images look the same visually
-        assert np.abs(ddpm_image - ddim_image).max() < 1e-1
-
-    # Make sure the test passes for different values of random seed
-    @parameterized.expand([(0,), (4,)])
-    def test_ddpm_ddim_equality_batched(self, seed):
+    def test_ddpm_ddim_equality_batched(self):
+        seed = 0
         model_id = "google/ddpm-cifar10-32"

         unet = UNet2DModel.from_pretrained(model_id)

@@ -554,12 +539,12 @@ class PipelineSlowTests(unittest.TestCase):
         ddim.to(torch_device)
         ddim.set_progress_bar_config(disable=None)

-        generator = torch.manual_seed(seed)
-        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
+        generator = torch.Generator(device=torch_device).manual_seed(seed)
+        ddpm_images = ddpm(batch_size=2, generator=generator, output_type="numpy").images

-        generator = torch.manual_seed(seed)
+        generator = torch.Generator(device=torch_device).manual_seed(seed)
         ddim_images = ddim(
-            batch_size=4,
+            batch_size=2,
             generator=generator,
             num_inference_steps=1000,
             eta=1.0,