
[tests] model-level device_map clarifications (#11681)

* add clarity in documentation for device_map

* docs

* fix how compiler tester mixins are used.

* propagate

* more

* typo.

* fix tests

* fix order of decorators.

* clarify more.

* more test cases.

* fix doc

* fix device_map docstring in pipeline_utils.

* more examples

* more

* update

* remove code for stuff that is already supported.

* fix stuff.
Author: Sayak Paul
Date: 2025-06-11 22:41:59 +05:30 (committed by GitHub)
Parent: b6f7933044
Commit: 91545666e0
3 changed files with 74 additions and 11 deletions

View File

@@ -814,14 +814,43 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be defined for each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device. Defaults to `None`, meaning that the model will be loaded on CPU.
Examples:
```py
>>> from diffusers import AutoModel
>>> import torch
>>> # This works.
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
... )
>>> # This also works (integer accelerator device ID).
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
... )
>>> # Specifying a supported offloading strategy like "auto" also works.
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
... )
>>> # Specifying a dictionary as `device_map` also works.
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0",
... subfolder="unet",
... device_map={"": torch.device("cuda")},
... )
```
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
map](https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap). You
can also refer to the [Diffusers-specific
documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
for more concrete examples.
max_memory (`Dict`, *optional*):
A dictionary mapping device identifiers to their maximum memory. Defaults to the maximum memory available
for each GPU and the available CPU RAM if unset.
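
As an aside on how `max_memory` interacts with `device_map="auto"` (this example is not part of the diff, and the memory budgets below are arbitrary placeholders), a minimal sketch might look like:

```py
from diffusers import AutoModel

# Illustrative budgets only: cap accelerator 0 at 10GiB and CPU RAM at 30GiB so
# that the "auto" strategy computes a placement within those limits.
model = AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    device_map="auto",
    max_memory={0: "10GiB", "cpu": "30GiB"},
)
```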
@@ -1387,7 +1416,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
low_cpu_mem_usage: bool = True,
dtype: Optional[Union[str, torch.dtype]] = None,
keep_in_fp32_modules: Optional[List[str]] = None,
device_map: Dict[str, Union[int, str, torch.device]] = None,
device_map: Union[str, int, torch.device, Dict[str, Union[int, str, torch.device]]] = None,
offload_state_dict: Optional[bool] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
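
The broadened `device_map` annotation sits alongside the existing offloading arguments in this signature. As a hedged sketch (the folder path is a placeholder, and it assumes the corresponding offloading arguments are exposed on the public `from_pretrained` as well), disk offloading can be combined with the `"auto"` strategy like so:

```py
from diffusers import AutoModel

# Hypothetical usage: if the "auto" placement spills beyond GPU and CPU memory,
# offloaded weights are written under offload_folder (a placeholder path here).
model = AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    device_map="auto",
    offload_folder="./offload",
    offload_state_dict=True,
)
```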

View File

@@ -669,14 +669,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be defined for each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device.
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
device_map (`str`, *optional*):
Strategy that dictates how the different components of a pipeline should be placed on available
devices. Currently, only "balanced" `device_map` is supported. Check out
[this](https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement)
to learn more.
max_memory (`Dict`, *optional*):
A dictionary mapping device identifiers to their maximum memory. Defaults to the maximum memory available
for each GPU and the available CPU RAM if unset.
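
The updated pipeline docstring above limits the pipeline-level `device_map` to the `"balanced"` strategy. A minimal usage sketch (assuming more than one accelerator is available and using an example checkpoint) would be:

```py
import torch
from diffusers import DiffusionPipeline

# "balanced" spreads the pipeline's components across the available accelerators.
pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    device_map="balanced",
)
image = pipeline("a photo of an astronaut riding a horse").images[0]
```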

View File

@@ -46,6 +46,7 @@ from diffusers.utils.testing_utils import (
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_gpu,
skip_mps,
slow,
torch_all_close,
@@ -1083,6 +1084,42 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
        assert loaded_model
        assert new_output.sample.shape == (4, 4, 16, 16)

    @parameterized.expand(
        [
            (-1, "You can't pass device_map as a negative int"),
            ("foo", "When passing device_map as a string, the value needs to be a device name"),
        ]
    )
    def test_wrong_device_map_raises_error(self, device_map, msg_substring):
        with self.assertRaises(ValueError) as err_ctx:
            _ = self.model_class.from_pretrained(
                "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
            )

        assert msg_substring in str(err_ctx.exception)

    @parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
    @require_torch_gpu
    def test_passing_non_dict_device_map_works(self, device_map):
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        loaded_model = self.model_class.from_pretrained(
            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
        )
        output = loaded_model(**inputs_dict)

        assert output.sample.shape == (4, 4, 16, 16)

    @parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
    @require_torch_gpu
    def test_passing_dict_device_map_works(self, name, device_map):
        # There are other valid dict-based `device_map` values too. It's best to refer to
        # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        loaded_model = self.model_class.from_pretrained(
            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map={name: device_map}
        )
        output = loaded_model(**inputs_dict)

        assert output.sample.shape == (4, 4, 16, 16)

    @require_peft_backend
    def test_load_attn_procs_raise_warning(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
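
The two negative test cases above pin down the expected error messages. Purely as an illustration of what validation consistent with those messages could look like (a hypothetical helper, not the code added by this commit), consider:

```py
import torch


def _validate_device_map(device_map):
    # Hypothetical sketch: reject the same inputs the tests above expect to fail.
    if isinstance(device_map, int) and device_map < 0:
        raise ValueError("You can't pass device_map as a negative int")
    if isinstance(device_map, str) and device_map not in ("auto", "balanced", "sequential", "balanced_low_0"):
        # A plain string must either name a supported strategy or parse as a device.
        try:
            torch.device(device_map)
        except RuntimeError as exc:
            raise ValueError(
                "When passing device_map as a string, the value needs to be a device name "
                f"(e.g. cpu, cuda), got: {device_map}."
            ) from exc
    return device_map
```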