1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Inference support for mps device (#355)

* Initial support for mps in Stable Diffusion pipeline.

* Initial "warmup" implementation when using mps.

* Make some deterministic tests pass with mps.

* Disable training tests when using mps.

* SD: generate latents in CPU then move to device.

This is especially important when using the mps device, because
generators are not supported there. See for example
https://github.com/pytorch/pytorch/issues/84288.

In addition, the other pipelines seem to use the same approach: generate
the random samples then move to the appropriate device.

After this change, generating an image in MPS produces the same result
as when using the CPU, if the same seed is used.

* Remove prints.

* Pass AutoencoderKL test_output_pretrained with mps.

Sampling from `posterior` must be done in CPU.

* Style

* Do not use torch.long for log op in mps device.

* Perform incompatible padding ops in CPU.

UNet tests now pass.
See https://github.com/pytorch/pytorch/issues/84535

* Style: fix import order.

* Remove unused symbols.

* Remove MPSWarmupMixin, do not apply automatically.

We do apply warmup in the tests, but not during normal use.
This adopts some PR suggestions by @patrickvonplaten.

* Add comment for mps fallback to CPU step.

* Add README_mps.md for mps installation and use.

* Apply `black` to modified files.

* Restrict README_mps to SD, show measures in table.

* Make PNDM indexing compatible with mps.

Addresses #239.

* Do not use float64 when using LDMScheduler.

Fixes #358.

* Fix typo identified by @patil-suraj

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Adapt example to new output style.

* Restore 1:1 results reproducibility with CompVis.

However, mps latents need to be generated in CPU because generators
don't work in the mps device.

* Move PyTorch nightly to requirements.

* Adapt `test_scheduler_outputs_equivalence` ton MPS.

* mps: skip training tests instead of ignoring silently.

* Make VQModel tests pass on mps.

* mps ddim tests: warmup, increase tolerance.

* ScoreSdeVeScheduler indexing made mps compatible.

* Make ldm pipeline tests pass using warmup.

* Style

* Simplify casting as suggested in PR.

* Add Known Issues to readme.

* `isort` import order.

* Remove _mps_warmup helpers from ModelMixin.

And just make changes to the tests.

* Skip tests using unittest decorator for consistency.

* Remove temporary var.

* Remove spurious blank space.

* Remove unused symbol.

* Remove README_mps.

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Pedro Cuenca
2022-09-08 13:37:36 +02:00
committed by GitHub
parent 98f346835a
commit 5dda1735fd
15 changed files with 92 additions and 14 deletions

View File

@@ -39,6 +39,10 @@ pip install --upgrade diffusers # should install diffusers 0.2.4
conda install -c conda-forge diffusers
```
**Apple Silicon (M1/M2) support**
Please, refer to [the documentation](https://huggingface.co/docs/diffusers/optimization/mps).
## Contributing
We ❤️ contributions from the open-source community!

View File

@@ -146,6 +146,7 @@ class BasicTransformerBlock(nn.Module):
self.attn2._slice_size = slice_size
def forward(self, x, context=None):
x = x.contiguous() if x.device.type == "mps" else x
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x

View File

@@ -448,10 +448,15 @@ def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
if input.device.type == "mps":
out = out.to("cpu")
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
out = out.to(input.device) # Move back to mps if necessary
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),

View File

@@ -171,7 +171,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
timesteps = timesteps.to(dtype=torch.float32)
timesteps = timesteps[None].to(device=sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])

View File

@@ -338,7 +338,10 @@ class DiagonalGaussianDistribution(object):
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
device = self.parameters.device
sample_device = "cpu" if device.type == "mps" else device
sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
x = self.mean + self.std * sample
return x
def kl(self, other=None):

View File

@@ -72,7 +72,6 @@ class ImagePipelineOutput(BaseOutput):
class DiffusionPipeline(ConfigMixin):
config_name = "model_index.json"
def register_modules(self, **kwargs):

View File

@@ -198,17 +198,22 @@ class StableDiffusionPipeline(DiffusionPipeline):
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
# get the initial random noise unless the user supplied it
# Unlike in other pipelines, latents need to be generated in the target device
# for 1-to-1 results reproducibility with the CompVis implementation.
# However this currently doesn't work in `mps`.
latents_device = "cpu" if self.device.type == "mps" else self.device
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None:
latents = torch.randn(
latents_shape,
generator=generator,
device=self.device,
device=latents_device,
)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
latents = latents.to(self.device)
latents = latents.to(self.device)
# set timesteps
accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())

View File

@@ -355,7 +355,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> torch.Tensor:
# mps requires indices to be in the same device, so we use cpu as is the default with cuda
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5

View File

@@ -139,7 +139,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
elif tensor_format == "pt":
return torch.where(
timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device)
timesteps == 0,
torch.zeros_like(t.to(timesteps.device)),
self.discrete_sigmas[timesteps - 1].to(timesteps.device),
)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
@@ -196,8 +198,11 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
) # torch.repeat_interleave(timestep, sample.shape[0])
timesteps = (timestep * (len(self.timesteps) - 1)).long()
# mps requires indices to be in the same device, so we use cpu as is the default with cuda
timesteps = timesteps.to(self.discrete_sigmas.device)
sigma = self.discrete_sigmas[timesteps].to(sample.device)
adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep)
adjacent_sigma = self.get_adjacent_sigma(timesteps, timestep).to(sample.device)
drift = self.zeros_like(sample)
diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5

View File

@@ -8,6 +8,7 @@ import torch
global_rng = random.Random()
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
torch_device = "mps" if torch.backends.mps.is_available() else torch_device
def parse_flag_from_env(key, default=False):

View File

@@ -15,11 +15,13 @@
import inspect
import tempfile
import unittest
from typing import Dict, List, Tuple
import numpy as np
import torch
from diffusers.modeling_utils import ModelMixin
from diffusers.testing_utils import torch_device
from diffusers.training_utils import EMAModel
@@ -38,6 +40,11 @@ class ModelTesterMixin:
new_model.to(torch_device)
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
_ = model(**self.dummy_input)
_ = new_model(**self.dummy_input)
image = model(**inputs_dict)
if isinstance(image, dict):
image = image.sample
@@ -55,7 +62,12 @@ class ModelTesterMixin:
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
@@ -132,6 +144,7 @@ class ModelTesterMixin:
self.assertEqual(output_1.shape, output_2.shape)
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
def test_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -147,6 +160,7 @@ class ModelTesterMixin:
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
@unittest.skipIf(torch_device == "mps", "Training is not supported in mps")
def test_ema_training(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -167,8 +181,13 @@ class ModelTesterMixin:
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
# Temporary fallback until `aten::_index_put_impl_` is implemented in mps
# Track progress in https://github.com/pytorch/pytorch/issues/77764
device = t.device
if device.type == "mps":
t = t.to("cpu")
t[t != t] = 0
return t
return t.to(device)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
@@ -198,7 +217,12 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)

View File

@@ -191,7 +191,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels = 3
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
return {"sample": noise, "timestep": time_step}

View File

@@ -18,6 +18,7 @@ import unittest
import torch
from diffusers import AutoencoderKL
from diffusers.modeling_utils import ModelMixin
from diffusers.testing_utils import floats_tensor, torch_device
from .test_modeling_common import ModelTesterMixin
@@ -80,6 +81,13 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
model = model.to(torch_device)
model.eval()
# One-time warmup pass (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
_ = model(image, sample_posterior=True).sample
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)

View File

@@ -85,6 +85,9 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = model(image)
output = model(image).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu()

View File

@@ -194,6 +194,10 @@ class PipelineFastTests(unittest.TestCase):
ddpm.to(torch_device)
ddpm.set_progress_bar_config(disable=None)
# Warmup pass when using mps (see #372)
if torch_device == "mps":
_ = ddpm(num_inference_steps=1)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
@@ -207,8 +211,9 @@ class PipelineFastTests(unittest.TestCase):
expected_slice = np.array(
[1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
tolerance = 1e-2 if torch_device != "mps" else 3e-2
assert np.abs(image_slice.flatten() - expected_slice).max() < tolerance
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < tolerance
def test_pndm_cifar10(self):
unet = self.dummy_uncond_unet
@@ -244,6 +249,14 @@ class PipelineFastTests(unittest.TestCase):
ldm.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
# Warmup pass when using mps (see #372)
if torch_device == "mps":
generator = torch.manual_seed(0)
_ = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy")[
"sample"
]
generator = torch.manual_seed(0)
image = ldm([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy")[
"sample"
@@ -473,6 +486,11 @@ class PipelineFastTests(unittest.TestCase):
ldm.to(torch_device)
ldm.set_progress_bar_config(disable=None)
# Warmup pass when using mps (see #372)
if torch_device == "mps":
generator = torch.manual_seed(0)
_ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images