mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
mps: remove warmup passes (#2771)
* Remove warmup passes in mps tests.
* Update mps docs: no warmup pass in PyTorch 2.
* Update imports.
@@ -19,17 +19,22 @@ specific language governing permissions and limitations under the License.
 - Mac computer with Apple silicon (M1/M2) hardware.
 - macOS 12.6 or later (13.0 or later recommended).
 - arm64 version of Python.
-- PyTorch 1.13. You can install it with `pip` or `conda` using the instructions in https://pytorch.org/get-started/locally/.
+- PyTorch 2.0 (recommended) or 1.13 (minimum version supported for `mps`). You can install it with `pip` or `conda` using the instructions in https://pytorch.org/get-started/locally/.
 
 ## Inference Pipeline
 
 The snippet below demonstrates how to use the `mps` backend using the familiar `to()` interface to move the Stable Diffusion pipeline to your M1 or M2 device.
 
-We recommend to "prime" the pipeline using an additional one-time pass through it. This is a temporary workaround for a weird issue we have detected: the first inference pass produces slightly different results than subsequent ones. You only need to do this pass once, and it's ok to use just one inference step and discard the result.
+<Tip warning={true}>
+
+**If you are using PyTorch 1.13** you need to "prime" the pipeline using an additional one-time pass through it. This is a temporary workaround for a weird issue we detected: the first inference pass produces slightly different results than subsequent ones. You only need to do this pass once, and it's ok to use just one inference step and discard the result.
+
+</Tip>
+
+We strongly recommend you use PyTorch 2 or better, as it solves a number of problems like the one described in the previous tip.
 
 ```python
 # make sure you're logged in with `huggingface-cli login`
 from diffusers import StableDiffusionPipeline
 
 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
@@ -40,7 +45,7 @@ pipe.enable_attention_slicing()
 
 prompt = "a photo of an astronaut riding a horse on mars"
 
-# First-time "warmup" pass (see explanation above)
+# First-time "warmup" pass if PyTorch version is 1.13 (see explanation above)
 _ = pipe(prompt, num_inference_steps=1)
 
 # Results match those from the CPU device after the warmup pass.
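
Taken together, the updated docs amount to a version-gated warmup: prime the pipeline on PyTorch 1.13, skip it on PyTorch 2. A minimal sketch of that logic (the `packaging` version check and the `mps` availability assert are our additions, not part of the diff):

```python
import torch
from packaging import version

from diffusers import StableDiffusionPipeline

# Sanity check that the arm64 PyTorch build with mps support is installed.
assert torch.backends.mps.is_available(), "mps backend not available"

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("mps")
pipe.enable_attention_slicing()  # recommended if your Mac has < 64 GB of RAM

prompt = "a photo of an astronaut riding a horse on mars"

if version.parse(torch.__version__) < version.parse("2.0"):
    # One-time "warmup" pass, only needed on PyTorch 1.13 (see tip above)
    _ = pipe(prompt, num_inference_steps=1)

image = pipe(prompt).images[0]
```
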
@@ -59,5 +64,4 @@ pipeline.enable_attention_slicing()
 
 ## Known Issues
 
-- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
 - Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). This is being resolved, but for now we recommend iterating instead of batching.
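
Because batched prompts are the failure mode in the remaining known issue, the interim workaround is to loop over prompts one at a time. A short sketch (the prompt list is illustrative):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("mps")
pipe.enable_attention_slicing()

prompts = ["a photo of an astronaut riding a horse on mars", "a watercolor of a fox"]

# Workaround for https://github.com/huggingface/diffusers/issues/363:
# run prompts one at a time instead of passing the whole list as a batch.
images = [pipe(prompt).images[0] for prompt in prompts]
```
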
@@ -20,7 +20,6 @@ import torch
 from parameterized import parameterized
 
 from diffusers import AutoencoderKL
-from diffusers.models import ModelMixin
 from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
 
 from ..test_modeling_common import ModelTesterMixin
@@ -124,12 +123,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         model = model.to(torch_device)
         model.eval()
 
-        # One-time warmup pass (see #372)
-        if torch_device == "mps" and isinstance(model, ModelMixin):
-            image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
-            image = image.to(torch_device)
-            with torch.no_grad():
-                _ = model(image, sample_posterior=True).sample
         if torch_device == "mps":
             generator = torch.manual_seed(0)
         else:
             generator = torch.Generator(device=torch_device).manual_seed(0)
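
The generator branch this hunk keeps reflects that `mps` did not accept device-local `torch.Generator` objects at the time, so the tests seed the global CPU generator instead. The same pattern, distilled into a hypothetical helper (the name `make_generator` is ours, not from the diff):

```python
import torch

def make_generator(device: str, seed: int = 0) -> torch.Generator:
    # mps lacked device-local generator support at the time, so fall back
    # to seeding the default CPU generator, which torch.manual_seed returns.
    if device == "mps":
        return torch.manual_seed(seed)
    return torch.Generator(device=device).manual_seed(seed)
```
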
@@ -85,9 +85,6 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         image = image.to(torch_device)
         with torch.no_grad():
-            # Warmup pass when using mps (see #372)
-            if torch_device == "mps":
-                _ = model(image)
             output = model(image).sample
 
         output_slice = output[0, -1, -3:, -3:].flatten().cpu()
@@ -74,10 +74,6 @@ class DDPMPipelineFastTests(unittest.TestCase):
         ddpm.to(torch_device)
         ddpm.set_progress_bar_config(disable=None)
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = ddpm(num_inference_steps=1)
-
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
@@ -79,11 +79,6 @@ class LDMPipelineFastTests(unittest.TestCase):
         ldm.to(torch_device)
         ldm.set_progress_bar_config(disable=None)
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(generator=generator, num_inference_steps=1, output_type="numpy").images
-
         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images
@@ -265,10 +265,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = pipe(**self.get_dummy_inputs(torch_device))
-
         output = pipe(**self.get_dummy_inputs(torch_device))[0]
         output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0]
@@ -24,7 +24,7 @@ import requests_mock
 import torch
 from requests.exceptions import HTTPError
 
-from diffusers.models import ModelMixin, UNet2DConditionModel
+from diffusers.models import UNet2DConditionModel
 from diffusers.models.attention_processor import AttnProcessor
 from diffusers.training_utils import EMAModel
 from diffusers.utils import torch_device
@@ -119,11 +119,6 @@ class ModelTesterMixin:
         new_model.to(torch_device)
 
         with torch.no_grad():
-            # Warmup pass when using mps (see #372)
-            if torch_device == "mps" and isinstance(model, ModelMixin):
-                _ = model(**self.dummy_input)
-                _ = new_model(**self.dummy_input)
-
             image = model(**inputs_dict)
             if isinstance(image, dict):
                 image = image.sample
@@ -161,11 +156,6 @@ class ModelTesterMixin:
         new_model.to(torch_device)
 
         with torch.no_grad():
-            # Warmup pass when using mps (see #372)
-            if torch_device == "mps" and isinstance(model, ModelMixin):
-                _ = model(**self.dummy_input)
-                _ = new_model(**self.dummy_input)
-
             image = model(**inputs_dict)
             if isinstance(image, dict):
                 image = image.sample
@@ -203,10 +193,6 @@ class ModelTesterMixin:
         model.eval()
 
         with torch.no_grad():
-            # Warmup pass when using mps (see #372)
-            if torch_device == "mps" and isinstance(model, ModelMixin):
-                model(**self.dummy_input)
-
             first = model(**inputs_dict)
             if isinstance(first, dict):
                 first = first.sample
@@ -377,10 +363,6 @@ class ModelTesterMixin:
         model.eval()
 
         with torch.no_grad():
-            # Warmup pass when using mps (see #372)
-            if torch_device == "mps" and isinstance(model, ModelMixin):
-                model(**self.dummy_input)
-
             outputs_dict = model(**inputs_dict)
             outputs_tuple = model(**inputs_dict, return_dict=False)
@@ -121,10 +121,6 @@ class PipelineTesterMixin:
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = pipe(**self.get_dummy_inputs(torch_device))
-
         inputs = self.get_dummy_inputs(torch_device)
         output = pipe(**inputs)[0]
@@ -327,10 +323,6 @@ class PipelineTesterMixin:
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = pipe(**self.get_dummy_inputs(torch_device))
-
         output = pipe(**self.get_dummy_inputs(torch_device))[0]
         output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0]
@@ -402,10 +394,6 @@ class PipelineTesterMixin:
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = pipe(**self.get_dummy_inputs(torch_device))
-
         # set all optional components to None
         for optional_component in pipe._optional_components:
             setattr(pipe, optional_component, None)
@@ -477,10 +465,6 @@ class PipelineTesterMixin:
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            _ = pipe(**self.get_dummy_inputs(torch_device))
-
         inputs = self.get_dummy_inputs(torch_device)
         output_without_slicing = pipe(**inputs)[0]
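
With the warmup passes gone, these fast tests rely on the first `mps` pass matching subsequent ones under PyTorch 2. A standalone way to sanity-check that property on your own machine (our sketch, not part of the test suite; `google/ddpm-cat-256` is just an example checkpoint):

```python
import numpy as np
import torch

from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256").to("mps")
pipe.set_progress_bar_config(disable=True)

def run(seed: int = 0) -> np.ndarray:
    # Seed the CPU generator so both runs start from identical noise.
    generator = torch.manual_seed(seed)
    return pipe(generator=generator, num_inference_steps=2, output_type="numpy").images

# On PyTorch 2 the first pass should already match a repeat run, no warmup needed.
first, second = run(), run()
print("max abs diff:", np.abs(first - second).max())
```
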