Inference support for mps device (#355)
* Initial support for mps in Stable Diffusion pipeline.
* Initial "warmup" implementation when using mps.
* Make some deterministic tests pass with mps.
* Disable training tests when using mps.
* SD: generate latents in CPU then move to device. This is especially important when using the mps device, because generators are not supported there. See for example https://github.com/pytorch/pytorch/issues/84288. In addition, the other pipelines seem to use the same approach: generate the random samples, then move to the appropriate device. After this change, generating an image on mps produces the same result as on the CPU if the same seed is used.
* Remove prints.
* Pass AutoencoderKL test_output_pretrained with mps. Sampling from `posterior` must be done in CPU.
* Style
* Do not use torch.long for log op in mps device.
* Perform incompatible padding ops in CPU. UNet tests now pass. See https://github.com/pytorch/pytorch/issues/84535
* Style: fix import order.
* Remove unused symbols.
* Remove MPSWarmupMixin, do not apply automatically. We do apply warmup in the tests, but not during normal use. This adopts some PR suggestions by @patrickvonplaten.
* Add comment for mps fallback to CPU step.
* Add README_mps.md for mps installation and use.
* Apply `black` to modified files.
* Restrict README_mps to SD, show measures in table.
* Make PNDM indexing compatible with mps. Addresses #239.
* Do not use float64 when using LDMScheduler. Fixes #358.
* Fix typo identified by @patil-suraj. Co-authored-by: Suraj Patil <surajp815@gmail.com>
* Adapt example to new output style.
* Restore 1:1 results reproducibility with CompVis. However, mps latents need to be generated in CPU because generators don't work in the mps device.
* Move PyTorch nightly to requirements.
* Adapt `test_scheduler_outputs_equivalence` to mps.
* mps: skip training tests instead of ignoring silently.
* Make VQModel tests pass on mps.
* mps ddim tests: warmup, increase tolerance.
* ScoreSdeVeScheduler indexing made mps compatible.
* Make ldm pipeline tests pass using warmup.
* Style
* Simplify casting as suggested in PR.
* Add Known Issues to readme.
* `isort` import order.
* Remove _mps_warmup helpers from ModelMixin. And just make changes to the tests.
* Skip tests using unittest decorator for consistency.
* Remove temporary var.
* Remove spurious blank space.
* Remove unused symbol.
* Remove README_mps.

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
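In practice, the reproducibility property described above can be exercised as follows; a minimal usage sketch, assuming the CompVis checkpoint id and the pipeline API of this era (`.images` on the call result):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("mps")

# The generator lives on the CPU; latents are drawn there and moved to mps,
# so the same seed reproduces the same image as a CPU run.
generator = torch.manual_seed(0)
image = pipe("a photo of an astronaut riding a horse", generator=generator).images[0]
```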
@@ -39,6 +39,10 @@ pip install --upgrade diffusers  # should install diffusers 0.2.4
 conda install -c conda-forge diffusers
 ```
 
+**Apple Silicon (M1/M2) support**
+
+Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
+
 ## Contributing
 
 We ❤️ contributions from the open-source community!
@@ -146,6 +146,7 @@ class BasicTransformerBlock(nn.Module):
         self.attn2._slice_size = slice_size
 
     def forward(self, x, context=None):
+        x = x.contiguous() if x.device.type == "mps" else x
         x = self.attn1(self.norm1(x)) + x
         x = self.attn2(self.norm2(x), context=context) + x
         x = self.ff(self.norm3(x)) + x
@@ -448,10 +448,15 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
     kernel_h, kernel_w = kernel.shape
 
     out = input.view(-1, in_h, 1, in_w, 1, minor)
+
+    # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
+    if input.device.type == "mps":
+        out = out.to("cpu")
     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
 
     out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+    out = out.to(input.device)  # Move back to mps if necessary
     out = out[
         :,
         max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
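The hop-to-CPU-and-back pattern above can be isolated into a small helper. A minimal sketch under the same assumption (6-D padding unsupported on mps, see pytorch#84535); the helper name is ours, not part of the diff:

```python
import torch
import torch.nn.functional as F

def pad_with_mps_fallback(t: torch.Tensor, pad: list) -> torch.Tensor:
    # Run the padding op on CPU when the tensor lives on mps,
    # then move the result back to the original device.
    device = t.device
    if device.type == "mps":
        t = t.to("cpu")
    return F.pad(t, pad).to(device)
```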
@@ -171,7 +171,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         if not torch.is_tensor(timesteps):
             timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
-            timesteps = timesteps[None].to(sample.device)
+            timesteps = timesteps.to(dtype=torch.float32)
+            timesteps = timesteps[None].to(device=sample.device)
 
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timesteps = timesteps.expand(sample.shape[0])
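The explicit float32 cast matters because mps has no float64 kernels; a 0-dim timestep arriving as float64 (for example, from a numpy scalar) would otherwise fail further down. A standalone sketch of the new branch, with an illustrative batch size:

```python
import torch

# e.g. a scheduler timestep that arrived as a numpy scalar (float64 by default)
timestep = torch.tensor(10.0, dtype=torch.float64)

timesteps = timestep.to(dtype=torch.float32)  # mps has no float64 kernels
timesteps = timesteps[None]                   # 0-dim -> shape (1,)
timesteps = timesteps.expand(4)               # broadcast to the batch without copying
```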
@@ -338,7 +338,10 @@ class DiagonalGaussianDistribution(object):
             self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
 
     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
-        x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
+        device = self.parameters.device
+        sample_device = "cpu" if device.type == "mps" else device
+        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+        x = self.mean + self.std * sample
        return x
 
     def kl(self, other=None):
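The same draw-on-CPU-then-move pattern, isolated from the class; the `mean` and `std` shapes here are illustrative:

```python
import torch

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
mean = torch.zeros(1, 4, 32, 32, device=device)
std = torch.ones_like(mean)

generator = torch.manual_seed(0)  # torch.Generator objects live on CPU
sample_device = "cpu" if device.type == "mps" else device
noise = torch.randn(mean.shape, generator=generator, device=sample_device).to(device)
x = mean + std * noise  # same values for a given seed on cpu and on mps
```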
@@ -72,7 +72,6 @@ class ImagePipelineOutput(BaseOutput):
 
 
 class DiffusionPipeline(ConfigMixin):
 
     config_name = "model_index.json"
 
     def register_modules(self, **kwargs):
@@ -198,17 +198,22 @@ class StableDiffusionPipeline(DiffusionPipeline):
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
         # get the initial random noise unless the user supplied it
+
+        # Unlike in other pipelines, latents need to be generated in the target device
+        # for 1-to-1 results reproducibility with the CompVis implementation.
+        # However this currently doesn't work in `mps`.
+        latents_device = "cpu" if self.device.type == "mps" else self.device
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
         if latents is None:
             latents = torch.randn(
                 latents_shape,
                 generator=generator,
-                device=self.device,
+                device=latents_device,
             )
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-            latents = latents.to(self.device)
+        latents = latents.to(self.device)
 
         # set timesteps
         accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
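One way to see why the latents are drawn on CPU: with the generator on CPU, the same seed yields identical latents regardless of the target device. A small check, with illustrative shapes:

```python
import torch

latents_shape = (1, 4, 64, 64)  # (batch, unet.in_channels, height // 8, width // 8)

def make_latents(seed: int, device: str) -> torch.Tensor:
    latents_device = "cpu" if device == "mps" else device
    generator = torch.manual_seed(seed)
    return torch.randn(latents_shape, generator=generator, device=latents_device).to(device)

cpu_latents = make_latents(0, "cpu")
if torch.backends.mps.is_available():
    mps_latents = make_latents(0, "mps")
    assert torch.allclose(cpu_latents, mps_latents.to("cpu"))
```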
@@ -355,7 +355,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         noise: Union[torch.FloatTensor, np.ndarray],
         timesteps: Union[torch.IntTensor, np.ndarray],
     ) -> torch.Tensor:
-
+        # mps requires indices to be in the same device, so we use cpu as is the default with cuda
+        timesteps = timesteps.to(self.alphas_cumprod.device)
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
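`alphas_cumprod` stays on whatever device the scheduler was created on (usually CPU), so the incoming `timesteps` are moved to it before indexing; on mps, indexing with an index tensor on a different device fails. A tiny repro of the rule, with a made-up table:

```python
import torch

alphas_cumprod = torch.linspace(1.0, 1e-4, 1000)  # scheduler table, created on cpu
timesteps = torch.tensor([10, 20, 30])
if torch.backends.mps.is_available():
    timesteps = timesteps.to("mps")               # e.g. arrived on the model's device

timesteps = timesteps.to(alphas_cumprod.device)   # align devices before fancy indexing
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
```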
@@ -139,7 +139,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
         elif tensor_format == "pt":
             return torch.where(
-                timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device)
+                timesteps == 0,
+                torch.zeros_like(t.to(timesteps.device)),
+                self.discrete_sigmas[timesteps - 1].to(timesteps.device),
             )
 
         raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
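`torch.where` likewise wants its condition and both branches on one device, which is why `t` and the sigma lookup are moved to the device of `timesteps`. A self-contained sketch with made-up sigma values:

```python
import torch

sigmas = torch.linspace(0.01, 50.0, 10)   # discrete sigma table, created on cpu
timesteps = torch.tensor([0, 3, 7])       # already aligned with the table's device
t = torch.full((3,), 0.5)
if torch.backends.mps.is_available():
    t = t.to("mps")                       # e.g. arrived on the model's device

adjacent_sigma = torch.where(
    timesteps == 0,
    torch.zeros_like(t.to(timesteps.device)),   # bring t alongside the indices
    sigmas[timesteps - 1].to(timesteps.device),
)
```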
@@ -196,8 +198,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         )  # torch.repeat_interleave(timestep, sample.shape[0])
         timesteps = (timestep * (len(self.timesteps) - 1)).long()
 
+        # mps requires indices to be in the same device, so we use cpu as is the default with cuda
+        timesteps = timesteps.to(self.discrete_sigmas.device)
+
         sigma = self.discrete_sigmas[timesteps].to(sample.device)
-        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep)
+        adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
         drift = self.zeros_like(sample)
         diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
 
@@ -8,6 +8,7 @@ import torch
 
 global_rng = random.Random()
 torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+torch_device = "mps" if torch.backends.mps.is_available() else torch_device
 
 
 def parse_flag_from_env(key, default=False):
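For completeness, the two-step device pick reads as below. `torch.backends.mps` only exists on PyTorch 1.12+, so a guard may be wanted on older versions; the guard is our addition, not in the diff:

```python
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    torch_device = "mps"
```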
@@ -15,11 +15,13 @@
 
 import inspect
 import tempfile
+import unittest
 from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
 
+from diffusers.modeling_utils import ModelMixin
 from diffusers.testing_utils import torch_device
 from diffusers.training_utils import EMAModel
 
@@ -38,6 +40,11 @@ class ModelTesterMixin:
         new_model.to(torch_device)
 
         with torch.no_grad():
+            # Warmup pass when using mps (see #372)
+            if torch_device == "mps" and isinstance(model, ModelMixin):
+                _ = model(**self.dummy_input)
+                _ = new_model(**self.dummy_input)
+
             image = model(**inputs_dict)
             if isinstance(image, dict):
                 image = image.sample
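The same warmup is repeated across several tests below: one throwaway forward pass is taken before the pass whose output is asserted, because the first inference pass on mps could return incorrect results at the time (issue #372). Factored into a helper, the pattern might look like this; the helper is our sketch, not part of the diff:

```python
import torch

def mps_warmup(model, dummy_input: dict) -> None:
    # The first forward pass on mps could produce wrong results
    # (huggingface/diffusers#372); run one throwaway pass first.
    if next(model.parameters()).device.type == "mps":
        with torch.no_grad():
            model(**dummy_input)
```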
@@ -55,7 +62,12 @@ class ModelTesterMixin:
         model = self.model_class(**init_dict)
         model.to(torch_device)
         model.eval()
+
         with torch.no_grad():
+            # Warmup pass when using mps (see #372)
+            if torch_device == "mps" and isinstance(model, ModelMixin):
+                model(**self.dummy_input)
+
             first = model(**inputs_dict)
             if isinstance(first, dict):
                 first = first.sample
@@ -132,6 +144,7 @@ class ModelTesterMixin:
 
         self.assertEqual(output_1.shape, output_2.shape)
 
+    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
     def test_training(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
@@ -147,6 +160,7 @@ class ModelTesterMixin:
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()
 
+    @unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
     def test_ema_training(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
@@ -167,8 +181,13 @@ class ModelTesterMixin:
 
     def test_scheduler_outputs_equivalence(self):
         def set_nan_tensor_to_zero(t):
+            # Temporary fallback until `aten::_index_put_impl_` is implemented in mps
+            # Track progress in https://github.com/pytorch/pytorch/issues/77764
+            device = t.device
+            if device.type == "mps":
+                t = t.to("cpu")
             t[t != t] = 0
-            return t
+            return t.to(device)
 
         def recursive_check(tuple_object, dict_object):
             if isinstance(tuple_object, (List, Tuple)):
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
@@ -198,7 +217,12 @@ class ModelTesterMixin:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
outputs_dict = model(**inputs_dict)
|
||||
outputs_tuple = model(**inputs_dict, return_dict=False)
|
||||
with torch.no_grad():
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps" and isinstance(model, ModelMixin):
|
||||
model(**self.dummy_input)
|
||||
|
||||
outputs_dict = model(**inputs_dict)
|
||||
outputs_tuple = model(**inputs_dict, return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
@@ -191,7 +191,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         num_channels = 3
 
         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
-        time_step = torch.tensor(batch_size * [10]).to(torch_device)
+        time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
 
         return {"sample": noise, "timestep": time_step}
 
@@ -18,6 +18,7 @@ import unittest
 import torch
 
 from diffusers import AutoencoderKL
+from diffusers.modeling_utils import ModelMixin
 from diffusers.testing_utils import floats_tensor, torch_device
 
 from .test_modeling_common import ModelTesterMixin
@@ -80,6 +81,13 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         model = model.to(torch_device)
         model.eval()
 
+        # One-time warmup pass (see #372)
+        if torch_device == "mps" and isinstance(model, ModelMixin):
+            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
+            image = image.to(torch_device)
+            with torch.no_grad():
+                _ = model(image, sample_posterior=True).sample
+
         torch.manual_seed(0)
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(0)
@@ -85,6 +85,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         image = image.to(torch_device)
         with torch.no_grad():
+            # Warmup pass when using mps (see #372)
+            if torch_device == "mps":
+                _ = model(image)
             output = model(image).sample
 
         output_slice = output[0, -1, -3:, -3:].flatten().cpu()
@@ -194,6 +194,10 @@ class PipelineFastTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            _ = ddpm(num_inference_steps=1)
+
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
@@ -207,8 +211,9 @@ class PipelineFastTests(unittest.TestCase):
         expected_slice = np.array(
             [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
         )
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+        tolerance = 1e-2 if torch_device != "mps" else 3e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
 
     def test_pndm_cifar10(self):
         unet = self.dummy_uncond_unet
@@ -244,6 +249,14 @@ class PipelineFastTests(unittest.TestCase):
         ldm.set_progress_bar_config(disable=None)
 
         prompt = "A painting of a squirrel eating a burger"
+
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+            _ = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy")[
+                "sample"
+            ]
+
         generator = torch.manual_seed(0)
         image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[
             "sample"
@@ -473,6 +486,11 @@ class PipelineFastTests(unittest.TestCase):
         ldm.to(torch_device)
         ldm.set_progress_bar_config(disable=None)
 
+        # Warmup pass when using mps (see #372)
+        if torch_device == "mps":
+            generator = torch.manual_seed(0)
+            _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images
+
         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images