diff --git a/README.md b/README.md
index 434d0cee2b..5dc7e12041 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,10 @@ pip install --upgrade diffusers  # should install diffusers 0.2.4
 conda install -c conda-forge diffusers
 ```
 
+**Apple Silicon (M1/M2) support**
+
+Please refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
+
 ## Contributing
 
 We ❤️ contributions from the open-source community!
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index a69d9014bd..5790a322fa 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -146,6 +146,7 @@ class BasicTransformerBlock(nn.Module):
         self.attn2._slice_size = slice_size
 
     def forward(self, x, context=None):
+        x = x.contiguous() if x.device.type == "mps" else x
         x = self.attn1(self.norm1(x)) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 51efea9ee4..27fae24f71 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -448,10 +448,15 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
     kernel_h, kernel_w = kernel.shape
 
     out = input.view(-1, in_h, 1, in_w, 1, minor)
+
+    # Temporary workaround for an mps-specific issue: https://github.com/pytorch/pytorch/issues/84535
+    if input.device.type == "mps":
+        out = out.to("cpu")
     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
 
     out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+    out = out.to(input.device)  # Move back to mps if necessary
     out = out[
         :,
         max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index 774b350283..ac1a8a9b75 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -171,7 +171,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         if not torch.is_tensor(timesteps):
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(sample.device)
+            timesteps = timesteps.to(dtype=torch.float32)
+            timesteps = timesteps[None].to(device=sample.device)
 
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index c0a185784c..b90a938aa8 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -338,7 +338,10 @@ class DiagonalGaussianDistribution(object):
             self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
 
     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
-        x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
+        device = self.parameters.device
+        sample_device = "cpu" if device.type == "mps" else device
+        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+        x = self.mean + self.std * sample
         return x
 
     def kl(self, other=None):
diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index fc2bc7bcf4..647d1ee313 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -72,7 +72,6 @@ class ImagePipelineOutput(BaseOutput):
 
 
 class DiffusionPipeline(ConfigMixin):
-
     config_name = "model_index.json"
 
     def register_modules(self, **kwargs):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 16ca977ea8..fe7653fe1d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -198,17 +198,22 @@ class StableDiffusionPipeline(DiffusionPipeline):
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
         # get the initial random noise unless the user supplied it
+
+        # Unlike in other pipelines, latents need to be generated on the target device
+        # for 1-to-1 results reproducibility with the CompVis implementation.
+        # However, this currently doesn't work on `mps`.
+        latents_device = "cpu" if self.device.type == "mps" else self.device
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
         if latents is None:
             latents = torch.randn(
                 latents_shape,
                 generator=generator,
-                device=self.device,
+                device=latents_device,
             )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+        latents = latents.to(self.device)
 
         # set timesteps
         accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index 171b509898..851dbf6cb1 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -355,7 +355,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         noise: Union[torch.FloatTensor, np.ndarray],
         timesteps: Union[torch.IntTensor, np.ndarray],
     ) -> torch.Tensor:
-
+        # mps requires indices to be on the same device, so we use cpu, as is the default with cuda
+        timesteps = timesteps.to(self.alphas_cumprod.device)
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py
index 308f42c91f..87786411b4 100644
--- a/src/diffusers/schedulers/scheduling_sde_ve.py
+++ b/src/diffusers/schedulers/scheduling_sde_ve.py
@@ -139,7 +139,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
         elif tensor_format == "pt":
             return torch.where(
-                timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device)
+                timesteps == 0,
+                torch.zeros_like(t.to(timesteps.device)),
+                self.discrete_sigmas[timesteps - 1].to(timesteps.device),
             )
 
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
@@ -196,8 +198,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         )  # torch.repeat_interleave(timestep, sample.shape[0])
         timesteps = (timestep * (len(self.timesteps) - 1)).long()
 
+        # mps requires indices to be on the same device, so we use cpu, as is the default with cuda
+        timesteps = timesteps.to(self.discrete_sigmas.device)
+
         sigma = self.discrete_sigmas[timesteps].to(sample.device)
-        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep)
+        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
         drift = self.zeros_like(sample)
         diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
 
diff --git a/src/diffusers/testing_utils.py b/src/diffusers/testing_utils.py
index 13f6332a94..a1288b4edb 100644
--- a/src/diffusers/testing_utils.py
+++ b/src/diffusers/testing_utils.py
@@ -8,6 +8,7 @@ import torch
 
 global_rng = random.Random()
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+torch_device = "mps" if torch.backends.mps.is_available() else torch_device
 
 
 def parse_flag_from_env(key, default=False):
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 8c7c6312de..7c098adbd8 100644
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -15,11 +15,13 @@
 
 import inspect
 import tempfile
+import unittest
 from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
 
+from diffusers.modeling_utils import ModelMixin
 from diffusers.testing_utils import torch_device
 from diffusers.training_utils import EMAModel
 
@@ -38,6 +40,11 @@ class ModelTesterMixin:
         new_model.to(torch_device)
 
         with torch.no_grad():
+            # Warmup pass when using mps (see #372)
+            if torch_device == "mps" and isinstance(model, ModelMixin):
+                _ = model(**self.dummy_input)
+                _ = new_model(**self.dummy_input)
+
             image = model(**inputs_dict)
             if isinstance(image, dict):
                 image = image.sample
@@ -55,7 +62,12 @@ class ModelTesterMixin:
         model = self.model_class(**init_dict)
         model.to(torch_device)
         model.eval()
+
         with torch.no_grad():
+            # Warmup pass when using mps (see #372)
+            if torch_device == "mps" and isinstance(model, ModelMixin):
+                model(**self.dummy_input)
+
             first = model(**inputs_dict)
             if isinstance(first, dict):
                 first = first.sample
@@ -132,6 +144,7 @@ class ModelTesterMixin:
 
         self.assertEqual(output_1.shape, output_2.shape)
 
+    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
     def test_training(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
@@ -147,6+160,7 @@ class ModelTesterMixin:
             loss = torch.nn.functional.mse_loss(output, noise)
             loss.backward()
 
+    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
     def test_ema_training(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
@@ -167,8 +181,13 @@ class ModelTesterMixin:
 
     def test_scheduler_outputs_equivalence(self):
         def set_nan_tensor_to_zero(t):
+            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
+            # Track progress in https://github.com/pytorch/pytorch/issues/77764
+            device = t.device
+            if device.type == "mps":
+                t = t.to("cpu")
             t[t != t] = 0
-            return t
+            return t.to(device)
 
         def recursive_check(tuple_object, dict_object):
             if isinstance(tuple_object, (List, Tuple)):
@@ -198,7 +217,12 @@ class ModelTesterMixin:
         model.to(torch_device)
         model.eval()
 
-        outputs_dict = model(**inputs_dict)
-        outputs_tuple = model(**inputs_dict, return_dict=False)
+        with torch.no_grad():
+            # Warmup pass when using mps (see #372)
+            if torch_device == "mps" and isinstance(model, ModelMixin):
+                model(**self.dummy_input)
+
+            outputs_dict = model(**inputs_dict)
+            outputs_tuple = model(**inputs_dict, return_dict=False)
 
         recursive_check(outputs_tuple, outputs_dict)
diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py
index c574a0092e..5e62693739 100644
--- a/tests/test_models_unet.py
+++ b/tests/test_models_unet.py
@@ -191,7 +191,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         num_channels = 3
 
         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [10]).to(torch_device)
+        time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
 
         return {"sample": noise, "timestep": time_step}
 
diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py
index adf9767d2d..c772dc7f63 100644
--- a/tests/test_models_vae.py
+++ b/tests/test_models_vae.py
@@ -18,6 +18,7 @@ import unittest
 import torch
 
 from diffusers import AutoencoderKL
+from diffusers.modeling_utils import ModelMixin
 from diffusers.testing_utils import floats_tensor, torch_device
 
 from .test_modeling_common import ModelTesterMixin
@@ -80,6 +81,13 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         model = model.to(torch_device)
         model.eval()
 
+        # One-time warmup pass (see #372)
+        if torch_device == "mps" and isinstance(model, ModelMixin):
+            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+            image = image.to(torch_device)
+            with torch.no_grad():
+                _ = model(image, sample_posterior=True).sample
+
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
diff --git a/tests/test_models_vq.py b/tests/test_models_vq.py
index c0acceccb4..69468efbb8 100644
--- a/tests/test_models_vq.py
+++ b/tests/test_models_vq.py
@@ -85,6 +85,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         image = image.to(torch_device)
         with torch.no_grad():
+            # Warmup pass when using mps (see #372)
+            if torch_device == "mps":
+                _ = model(image)
             output = model(image).sample
 
         output_slice = output[0, -1, -3:, -3:].flatten().cpu()
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index a05d57a73d..4e304292be 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -194,6 +194,10 @@ class PipelineFastTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
 
@@ -207,8 +211,9 @@ class PipelineFastTests(unittest.TestCase):
         expected_slice = np.array(
             [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
         )
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
 
     def test_pndm_cifar10(self):
         unet = self.dummy_uncond_unet
@@ -244,6 +249,14 @@ class PipelineFastTests(unittest.TestCase):
         ldm.set_progress_bar_config(disable=None)
 
         prompt = "A painting of a squirrel eating a burger"
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+            _ = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy")[
+                "sample"
+            ]
+
         generator = torch.manual_seed(0)
         image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[
             "sample"
@@ -473,6 +486,11 @@ class PipelineFastTests(unittest.TestCase):
         ldm.to(torch_device)
         ldm.set_progress_bar_config(disable=None)
 
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+            _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images
+
         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images
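
Usage note (not part of the patch): the flow these changes enable can be exercised end to end with the minimal sketch below. It rests on a few assumptions: diffusers with this patch applied, a PyTorch build with mps support, and access to the "CompVis/stable-diffusion-v1-4" weights (that model id and `use_auth_token=True` reflect a typical setup, not anything this diff prescribes). The `.images` output field follows the current pipeline API; older snapshots returned a dict keyed by "sample".

    import torch
    from diffusers import StableDiffusionPipeline

    # Pick the device the same way testing_utils.py now does.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", use_auth_token=True
    )
    pipe = pipe.to(device)

    prompt = "a photo of an astronaut riding a horse on mars"

    # One-time warmup pass, mirroring the tests above (see #372): the first
    # inference on mps can produce incorrect results, so it is thrown away.
    if device == "mps":
        _ = pipe(prompt, num_inference_steps=1)

    # With this patch, latents are generated on cpu when running on mps, so
    # seeding the default (cpu) generator is enough for reproducible output.
    generator = torch.manual_seed(0)
    image = pipe(prompt, generator=generator, num_inference_steps=50).images[0]
    image.save("astronaut.png")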