From 5dda1735fda047f4242d28f91e6e457b9760d52d Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 8 Sep 2022 13:37:36 +0200 Subject: [PATCH] Inference support for `mps` device (#355) * Initial support for mps in Stable Diffusion pipeline. * Initial "warmup" implementation when using mps. * Make some deterministic tests pass with mps. * Disable training tests when using mps. * SD: generate latents in CPU then move to device. This is especially important when using the mps device, because generators are not supported there. See for example https://github.com/pytorch/pytorch/issues/84288. In addition, the other pipelines seem to use the same approach: generate the random samples then move to the appropriate device. After this change, generating an image in MPS produces the same result as when using the CPU, if the same seed is used. * Remove prints. * Pass AutoencoderKL test_output_pretrained with mps. Sampling from `posterior` must be done in CPU. * Style * Do not use torch.long for log op in mps device. * Perform incompatible padding ops in CPU. UNet tests now pass. See https://github.com/pytorch/pytorch/issues/84535 * Style: fix import order. * Remove unused symbols. * Remove MPSWarmupMixin, do not apply automatically. We do apply warmup in the tests, but not during normal use. This adopts some PR suggestions by @patrickvonplaten. * Add comment for mps fallback to CPU step. * Add README_mps.md for mps installation and use. * Apply `black` to modified files. * Restrict README_mps to SD, show measures in table. * Make PNDM indexing compatible with mps. Addresses #239. * Do not use float64 when using LDMScheduler. Fixes #358. * Fix typo identified by @patil-suraj Co-authored-by: Suraj Patil * Adapt example to new output style. * Restore 1:1 results reproducibility with CompVis. However, mps latents need to be generated in CPU because generators don't work in the mps device. * Move PyTorch nightly to requirements. * Adapt `test_scheduler_outputs_equivalence` ton MPS. * mps: skip training tests instead of ignoring silently. * Make VQModel tests pass on mps. * mps ddim tests: warmup, increase tolerance. * ScoreSdeVeScheduler indexing made mps compatible. * Make ldm pipeline tests pass using warmup. * Style * Simplify casting as suggested in PR. * Add Known Issues to readme. * `isort` import order. * Remove _mps_warmup helpers from ModelMixin. And just make changes to the tests. * Skip tests using unittest decorator for consistency. * Remove temporary var. * Remove spurious blank space. * Remove unused symbol. * Remove README_mps. Co-authored-by: Suraj Patil Co-authored-by: Patrick von Platen --- README.md | 4 +++ src/diffusers/models/attention.py | 1 + src/diffusers/models/resnet.py | 5 ++++ src/diffusers/models/unet_2d_condition.py | 3 +- src/diffusers/models/vae.py | 5 +++- src/diffusers/pipeline_utils.py | 1 - .../pipeline_stable_diffusion.py | 9 ++++-- src/diffusers/schedulers/scheduling_pndm.py | 3 +- src/diffusers/schedulers/scheduling_sde_ve.py | 9 ++++-- src/diffusers/testing_utils.py | 1 + tests/test_modeling_common.py | 30 +++++++++++++++++-- tests/test_models_unet.py | 2 +- tests/test_models_vae.py | 8 +++++ tests/test_models_vq.py | 3 ++ tests/test_pipelines.py | 22 ++++++++++++-- 15 files changed, 92 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 434d0cee2b..5dc7e12041 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,10 @@ pip install --upgrade diffusers # should install diffusers 0.2.4 conda install -c conda-forge diffusers ``` +**Apple Silicon (M1/M2) support** + +Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps). + ## Contributing We ❤️ contributions from the open-source community! diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index a69d9014bd..5790a322fa 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -146,6 +146,7 @@ class BasicTransformerBlock(nn.Module): self.attn2._slice_size = slice_size def forward(self, x, context=None): + x = x.contiguous() if x.device.type == "mps" else x x = self.attn1(self.norm1(x)) + x x = self.attn2(self.norm2(x), context=context) + x x = self.ff(self.norm3(x)) + x diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 51efea9ee4..27fae24f71 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -448,10 +448,15 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)): kernel_h, kernel_w = kernel.shape out = input.view(-1, in_h, 1, in_w, 1, minor) + + # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535 + if input.device.type == "mps": + out = out.to("cpu") out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out.to(input.device) # Move back to mps if necessary out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 774b350283..ac1a8a9b75 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -171,7 +171,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): if not torch.is_tensor(timesteps): timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) + timesteps = timesteps.to(dtype=torch.float32) + timesteps = timesteps[None].to(device=sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index c0a185784c..b90a938aa8 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -338,7 +338,10 @@ class DiagonalGaussianDistribution(object): self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: - x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device) + device = self.parameters.device + sample_device = "cpu" if device.type == "mps" else device + sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device) + x = self.mean + self.std * sample return x def kl(self, other=None): diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index fc2bc7bcf4..647d1ee313 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -72,7 +72,6 @@ class ImagePipelineOutput(BaseOutput): class DiffusionPipeline(ConfigMixin): - config_name = "model_index.json" def register_modules(self, **kwargs): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 16ca977ea8..fe7653fe1d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -198,17 +198,22 @@ class StableDiffusionPipeline(DiffusionPipeline): text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_device = "cpu" if self.device.type == "mps" else self.device latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) if latents is None: latents = torch.randn( latents_shape, generator=generator, - device=self.device, + device=latents_device, ) else: if latents.shape != latents_shape: raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - latents = latents.to(self.device) + latents = latents.to(self.device) # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 171b509898..851dbf6cb1 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -355,7 +355,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> torch.Tensor: - + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 308f42c91f..87786411b4 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -139,7 +139,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1]) elif tensor_format == "pt": return torch.where( - timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device) + timesteps == 0, + torch.zeros_like(t.to(timesteps.device)), + self.discrete_sigmas[timesteps - 1].to(timesteps.device), ) raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") @@ -196,8 +198,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): ) # torch.repeat_interleave(timestep, sample.shape[0]) timesteps = (timestep * (len(self.timesteps) - 1)).long() + # mps requires indices to be in the same device, so we use cpu as is the default with cuda + timesteps = timesteps.to(self.discrete_sigmas.device) + sigma = self.discrete_sigmas[timesteps].to(sample.device) - adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep) + adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device) drift = self.zeros_like(sample) diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5 diff --git a/src/diffusers/testing_utils.py b/src/diffusers/testing_utils.py index 13f6332a94..a1288b4edb 100644 --- a/src/diffusers/testing_utils.py +++ b/src/diffusers/testing_utils.py @@ -8,6 +8,7 @@ import torch global_rng = random.Random() torch_device = "cuda" if torch.cuda.is_available() else "cpu" +torch_device = "mps" if torch.backends.mps.is_available() else torch_device def parse_flag_from_env(key, default=False): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8c7c6312de..7c098adbd8 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -15,11 +15,13 @@ import inspect import tempfile +import unittest from typing import Dict, List, Tuple import numpy as np import torch +from diffusers.modeling_utils import ModelMixin from diffusers.testing_utils import torch_device from diffusers.training_utils import EMAModel @@ -38,6 +40,11 @@ class ModelTesterMixin: new_model.to(torch_device) with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + _ = model(**self.dummy_input) + _ = new_model(**self.dummy_input) + image = model(**inputs_dict) if isinstance(image, dict): image = image.sample @@ -55,7 +62,12 @@ class ModelTesterMixin: model = self.model_class(**init_dict) model.to(torch_device) model.eval() + with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + model(**self.dummy_input) + first = model(**inputs_dict) if isinstance(first, dict): first = first.sample @@ -132,6 +144,7 @@ class ModelTesterMixin: self.assertEqual(output_1.shape, output_2.shape) + @unittest.skipIf(torch_device == "mps", "Training is not supported in mps") def test_training(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -147,6 +160,7 @@ class ModelTesterMixin: loss = torch.nn.functional.mse_loss(output, noise) loss.backward() + @unittest.skipIf(torch_device == "mps", "Training is not supported in mps") def test_ema_training(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -167,8 +181,13 @@ class ModelTesterMixin: def test_scheduler_outputs_equivalence(self): def set_nan_tensor_to_zero(t): + # Temporary fallback until `aten::_index_put_impl_` is implemented in mps + # Track progress in https://github.com/pytorch/pytorch/issues/77764 + device = t.device + if device.type == "mps": + t = t.to("cpu") t[t != t] = 0 - return t + return t.to(device) def recursive_check(tuple_object, dict_object): if isinstance(tuple_object, (List, Tuple)): @@ -198,7 +217,12 @@ class ModelTesterMixin: model.to(torch_device) model.eval() - outputs_dict = model(**inputs_dict) - outputs_tuple = model(**inputs_dict, return_dict=False) + with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + model(**self.dummy_input) + + outputs_dict = model(**inputs_dict) + outputs_tuple = model(**inputs_dict, return_dict=False) recursive_check(outputs_tuple, outputs_dict) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index c574a0092e..5e62693739 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -191,7 +191,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): num_channels = 3 noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) - time_step = torch.tensor(batch_size * [10]).to(torch_device) + time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device) return {"sample": noise, "timestep": time_step} diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py index adf9767d2d..c772dc7f63 100644 --- a/tests/test_models_vae.py +++ b/tests/test_models_vae.py @@ -18,6 +18,7 @@ import unittest import torch from diffusers import AutoencoderKL +from diffusers.modeling_utils import ModelMixin from diffusers.testing_utils import floats_tensor, torch_device from .test_modeling_common import ModelTesterMixin @@ -80,6 +81,13 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): model = model.to(torch_device) model.eval() + # One-time warmup pass (see #372) + if torch_device == "mps" and isinstance(model, ModelMixin): + image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) + image = image.to(torch_device) + with torch.no_grad(): + _ = model(image, sample_posterior=True).sample + torch.manual_seed(0) if torch.cuda.is_available(): torch.cuda.manual_seed_all(0) diff --git a/tests/test_models_vq.py b/tests/test_models_vq.py index c0acceccb4..69468efbb8 100644 --- a/tests/test_models_vq.py +++ b/tests/test_models_vq.py @@ -85,6 +85,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = image.to(torch_device) with torch.no_grad(): + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = model(image) output = model(image).sample output_slice = output[0, -1, -3:, -3:].flatten().cpu() diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index a05d57a73d..4e304292be 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -194,6 +194,10 @@ class PipelineFastTests(unittest.TestCase): ddpm.to(torch_device) ddpm.set_progress_bar_config(disable=None) + # Warmup pass when using mps (see #372) + if torch_device == "mps": + _ = ddpm(num_inference_steps=1) + generator = torch.manual_seed(0) image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images @@ -207,8 +211,9 @@ class PipelineFastTests(unittest.TestCase): expected_slice = np.array( [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] ) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 - assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + tolerance = 1e-2 if torch_device != "mps" else 3e-2 + assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance def test_pndm_cifar10(self): unet = self.dummy_uncond_unet @@ -244,6 +249,14 @@ class PipelineFastTests(unittest.TestCase): ldm.set_progress_bar_config(disable=None) prompt = "A painting of a squirrel eating a burger" + + # Warmup pass when using mps (see #372) + if torch_device == "mps": + generator = torch.manual_seed(0) + _ = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy")[ + "sample" + ] + generator = torch.manual_seed(0) image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[ "sample" @@ -473,6 +486,11 @@ class PipelineFastTests(unittest.TestCase): ldm.to(torch_device) ldm.set_progress_bar_config(disable=None) + # Warmup pass when using mps (see #372) + if torch_device == "mps": + generator = torch.manual_seed(0) + _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images + generator = torch.manual_seed(0) image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images