
[Core] add "balanced" device_map support to pipelines (#6857)

* get device <-> component mapping when using multiple gpus.

* condition the device_map bits.

* relax condition

* device_map progress.

* device_map enhancement

* some cleaning up and debugging

* Apply suggestions from code review

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* incorporate suggestions from PR.

* remove multi-gpu condition for now.

* guard check the component -> device mapping

* fix: device_memory variable

* dispatching transformers model to have force_hooks=True

* better guarding for transformers device_map

* introduce support for balanced_low_memory and balanced_ultra_low_memory.

* remove device_map patch.

* fix: intermediate variable scoping.

* fix: condition in cpu offload.

* fix: flax class restrictions.

* remove modifications from cpu_offload and model_offload

* incorporate changes.

* add a simple forward pass test

* add: torch_device in get_inputs()

* add: tests

* remove print

* safe-guard to(), model offloading and cpu offloading when balanced is used as a device_map.

* style

* remove .

* safeguard device_map with more checks and remove invalid device_mapping strategies.

* make a class attribute and adjust tests accordingly.

* fix device_map check

* fix test

* adjust comment

* fix: device_map attribute

* fix: dispatching.

* max_memory test for pipeline

* version guard the tests

* fix guard.

* address review feedback.

* reset_device_map method.

* add: test for reset_hf_device_map

* fix a couple things.

* add reset_device_map() in the error message.

* add tests for checking reset_device_map doesn't have unintended consequences.

* fix reset_device_map and offloading tests.

* create _get_final_device_map utility.

* hf_device_map -> _hf_device_map

* add documentation

* add notes suggested by Marc.

* styling.

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* move updates within gpu condition.

* other docs related things

* note on ignoring a device not specified in .

* provide a suggestion if device mapping errors out.

* fix: typo.

* _hf_device_map -> hf_device_map

* Empty-Commit

* add: example hf_device_map.

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Sayak Paul
2024-04-10 08:59:05 +05:30
committed by GitHub
parent c827e94da0
commit 3e4a6bd2d4
7 changed files with 546 additions and 17 deletions
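
For reference, a minimal usage sketch of the feature this commit adds, assuming a machine with more than one GPU and accelerate>=0.28.0; the model id, dtype, and the printed mapping are illustrative rather than taken from this diff:

import torch
from diffusers import DiffusionPipeline

# "balanced" distributes the pipeline's model-level components across the
# available GPUs; the exact placement depends on free memory per device.
pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, device_map="balanced"
)
print(pipeline.hf_device_map)  # e.g. {"unet": 0, "text_encoder": 1, "vae": 1, ...}

image = pipeline("a photo of an astronaut riding a horse").images[0]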

View File

@@ -699,6 +699,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
force_hooks=True,
)
except AttributeError as e:
# When using accelerate loading, we do not have the ability to load the state
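# For context: `force_hooks=True` asks Accelerate to attach its dispatch hooks even
# when the whole model ends up on a single device. A standalone sketch with a toy
# module and a hypothetical single-GPU map (not the diffusers call above):
import torch
from accelerate import dispatch_model

toy = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
# Without force_hooks, a single-device map installs no hooks; with it, the hooks
# are attached and inputs are moved to the module's device automatically.
toy = dispatch_model(toy, device_map={"": 0}, force_hooks=True)
out = toy(torch.randn(1, 8))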

View File

@@ -22,15 +22,19 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
from huggingface_hub import model_info
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version
from .. import __version__
from ..utils import (
FLAX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
get_class_from_dynamic_module,
is_accelerate_available,
is_peft_available,
is_transformers_available,
logging,
@@ -44,9 +48,12 @@ if is_transformers_available():
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
if is_accelerate_available():
import accelerate
from accelerate import dispatch_model
from accelerate.hooks import remove_hook_from_module
from accelerate.utils import compute_module_sizes, get_max_memory
INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -376,6 +383,207 @@ def _get_pipeline_class(
return pipeline_cls
def _load_empty_model(
library_name: str,
class_name: str,
importable_classes: List[Any],
pipelines: Any,
is_pipeline_module: bool,
name: str,
torch_dtype: Union[str, torch.dtype],
cached_folder: Union[str, os.PathLike],
**kwargs,
):
# retrieve class objects.
class_obj, _ = get_class_obj_and_candidates(
library_name,
class_name,
importable_classes,
pipelines,
is_pipeline_module,
component_name=name,
cache_dir=cached_folder,
)
if is_transformers_available():
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
else:
transformers_version = "N/A"
# Determine library.
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and transformers_version >= version.parse("4.20.0")
)
diffusers_module = importlib.import_module(__name__.split(".")[0])
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
model = None
config_path = cached_folder
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
if is_diffusers_model:
# Load config and then the model on meta.
config, unused_kwargs, commit_hash = class_obj.load_config(
os.path.join(config_path, name),
cache_dir=cached_folder,
return_unused_kwargs=True,
return_commit_hash=True,
force_download=kwargs.pop("force_download", False),
resume_download=kwargs.pop("resume_download", False),
proxies=kwargs.pop("proxies", None),
local_files_only=kwargs.pop("local_files_only", False),
token=kwargs.pop("token", None),
revision=kwargs.pop("revision", None),
subfolder=kwargs.pop("subfolder", None),
user_agent=user_agent,
)
with accelerate.init_empty_weights():
model = class_obj.from_config(config, **unused_kwargs)
elif is_transformers_model:
config_class = getattr(class_obj, "config_class", None)
if config_class is None:
raise ValueError("`config_class` cannot be None. Please double-check the model.")
config = config_class.from_pretrained(
cached_folder,
subfolder=name,
force_download=kwargs.pop("force_download", False),
resume_download=kwargs.pop("resume_download", False),
proxies=kwargs.pop("proxies", None),
local_files_only=kwargs.pop("local_files_only", False),
token=kwargs.pop("token", None),
revision=kwargs.pop("revision", None),
user_agent=user_agent,
)
with accelerate.init_empty_weights():
model = class_obj(config)
if model is not None:
model = model.to(dtype=torch_dtype)
return model
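# The meta-device loading above can be reproduced in isolation. A small sketch
# (toy module rather than a real pipeline component) of measuring a module's
# memory footprint without materializing its weights:
import torch
from accelerate import init_empty_weights
from accelerate.utils import compute_module_sizes

with init_empty_weights():
    # Parameters are created on the "meta" device, so no memory is allocated.
    toy = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))

# The "" key is the size of the whole module in bytes for the given dtype;
# this is how the pipeline components are ranked further below.
print(compute_module_sizes(toy, dtype=torch.float16)[""])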
def _assign_components_to_devices(
module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced"
):
device_ids = list(device_memory.keys())
device_cycle = device_ids + device_ids[::-1]
device_memory = device_memory.copy()
device_id_component_mapping = {}
current_device_index = 0
for component in module_sizes:
device_id = device_cycle[current_device_index % len(device_cycle)]
component_memory = module_sizes[component]
curr_device_memory = device_memory[device_id]
# If the GPU doesn't fit the current component offload to the CPU.
if component_memory > curr_device_memory:
device_id_component_mapping["cpu"] = [component]
else:
if device_id not in device_id_component_mapping:
device_id_component_mapping[device_id] = [component]
else:
device_id_component_mapping[device_id].append(component)
# Update the device memory.
device_memory[device_id] -= component_memory
current_device_index += 1
return device_id_component_mapping
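# A worked example of the helper above with made-up sizes and per-GPU budgets
# (bytes); the expected result is traced from the loop, not executed:
module_sizes = {"unet": 3_500_000_000, "text_encoder": 500_000_000, "vae": 300_000_000}
device_memory = {0: 4_000_000_000, 1: 4_000_000_000}

# Components are visited in the (descending-size) order of module_sizes while
# devices are cycled 0, 1, 1, 0, ...; here that yields
# {0: ["unet"], 1: ["text_encoder", "vae"]}.
mapping = _assign_components_to_devices(module_sizes, device_memory)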
def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
# To avoid circular import problem.
from diffusers import pipelines
torch_dtype = kwargs.get("torch_dtype", torch.float32)
# Load each module in the pipeline on a meta device so that we can derive the device map.
init_empty_modules = {}
for name, (library_name, class_name) in init_dict.items():
if class_name.startswith("Flax"):
raise ValueError("Flax pipelines are not supported with `device_map`.")
# Define all importable classes
is_pipeline_module = hasattr(pipelines, library_name)
importable_classes = ALL_IMPORTABLE_CLASSES
loaded_sub_model = None
# Use passed sub model or load class_name from library_name
if name in passed_class_obj:
# if the model is in a pipeline module, then we load it from the pipeline
# check that passed_class_obj has correct parent class
maybe_raise_or_warn(
library_name,
library,
class_name,
importable_classes,
passed_class_obj,
name,
is_pipeline_module,
)
with accelerate.init_empty_weights():
loaded_sub_model = passed_class_obj[name]
else:
loaded_sub_model = _load_empty_model(
library_name=library_name,
class_name=class_name,
importable_classes=importable_classes,
pipelines=pipelines,
is_pipeline_module=is_pipeline_module,
pipeline_class=pipeline_class,
name=name,
torch_dtype=torch_dtype,
cached_folder=kwargs.get("cached_folder", None),
force_download=kwargs.get("force_download", None),
resume_download=kwargs.get("resume_download", None),
proxies=kwargs.get("proxies", None),
local_files_only=kwargs.get("local_files_only", None),
token=kwargs.get("token", None),
revision=kwargs.get("revision", None),
)
if loaded_sub_model is not None:
init_empty_modules[name] = loaded_sub_model
# determine device map
# Obtain a sorted dictionary for mapping the model-level components
# to their sizes.
module_sizes = {
module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
for module_name, module in init_empty_modules.items()
if isinstance(module, torch.nn.Module)
}
module_sizes = dict(sorted(module_sizes.items(), key=lambda item: item[1], reverse=True))
# Obtain maximum memory available per device (GPUs only).
max_memory = get_max_memory(max_memory)
max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}
# Obtain a dictionary mapping the model-level components to the available
# devices based on the maximum memory and the model sizes.
device_id_component_mapping = _assign_components_to_devices(
module_sizes, max_memory, device_mapping_strategy=device_map
)
# Obtain the final device map, e.g., `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`
final_device_map = {}
for device_id, components in device_id_component_mapping.items():
for component in components:
final_device_map[component] = device_id
return final_device_map
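# For orientation, the Accelerate utility that supplies the per-device budgets
# above behaves roughly like this on a hypothetical two-GPU machine (numbers
# are illustrative):
from accelerate.utils import get_max_memory

# With no user-supplied max_memory, this reports the available memory per
# device, e.g. {0: 23e9, 1: 23e9, "cpu": 120e9}; the "cpu" entry is dropped
# above so only GPUs take part in the balanced assignment.
print(get_max_memory(None))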
def load_sub_model(
library_name: str,
class_name: str,
@@ -493,6 +701,22 @@ def load_sub_model(
# else load from the root directory
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
# remove hooks
remove_hook_from_module(loaded_sub_model, recurse=True)
needs_offloading_to_cpu = device_map[""] == "cpu"
if needs_offloading_to_cpu:
dispatch_model(
loaded_sub_model,
state_dict=loaded_sub_model.state_dict(),
device_map=device_map,
force_hooks=True,
main_device=0,
)
else:
dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
return loaded_sub_model

View File

@@ -73,6 +73,7 @@ from .pipeline_loading_utils import (
LOADABLE_CLASSES,
_fetch_class_library_tuple,
_get_custom_pipeline_class,
_get_final_device_map,
_get_pipeline_class,
_unwrap_model,
is_safetensors_compatible,
@@ -91,6 +92,8 @@ LIBRARIES = []
for library in LOADABLE_CLASSES:
LIBRARIES.append(library)
SUPPORTED_DEVICE_MAP = ["balanced"]
logger = logging.get_logger(__name__)
@@ -141,6 +144,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
config_name = "model_index.json"
model_cpu_offload_seq = None
hf_device_map = None
_optional_components = []
_exclude_from_cpu_offload = []
_load_connected_pipes = False
@@ -389,6 +393,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
)
# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
@@ -642,18 +652,35 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if device_map is not None and not is_accelerate_available():
raise NotImplementedError(
"Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
)
if device_map is not None and not isinstance(device_map, str):
raise ValueError("`device_map` must be a string.")
if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
raise NotImplementedError(
f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
)
if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
if is_accelerate_version("<", "0.28.0"):
raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
@@ -729,6 +756,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
revision=custom_revision,
)
if device_map is not None and pipeline_class._load_connected_pipes:
raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
# DEPRECATED: To be removed in 1.0.0
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
version.parse(config_dict["_diffusers_version"]).base_version
@@ -795,17 +825,45 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# import it here to avoid circular import
from diffusers import pipelines
# 6. device map delegation
final_device_map = None
if device_map is not None:
final_device_map = _get_final_device_map(
device_map=device_map,
pipeline_class=pipeline_class,
passed_class_obj=passed_class_obj,
init_dict=init_dict,
library=library,
max_memory=max_memory,
torch_dtype=torch_dtype,
cached_folder=cached_folder,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
)
# 7. Load each module in the pipeline
current_device_map = None
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
if final_device_map is not None and len(final_device_map) > 0:
component_device = final_device_map.get(name, None)
if component_device is not None:
current_device_map = {"": component_device}
else:
current_device_map = None
# 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
# 7.2 Define all importable classes
is_pipeline_module = hasattr(pipelines, library_name)
importable_classes = ALL_IMPORTABLE_CLASSES
loaded_sub_model = None
# 7.3 Use passed sub model or load class_name from library_name
if name in passed_class_obj:
# if the model is in a pipeline module, then we load it from the pipeline
# check that passed_class_obj has correct parent class
@@ -826,7 +884,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
torch_dtype=torch_dtype,
provider=provider,
sess_options=sess_options,
device_map=current_device_map,
max_memory=max_memory,
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
@@ -893,7 +951,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
)
# 8. Potentially add passed objects if expected
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
@@ -906,11 +964,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
# 10. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
# 11. Save where the model was instantiated from
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
if device_map is not None:
setattr(model, "hf_device_map", final_device_map)
return model
@property
@@ -963,6 +1023,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda".
"""
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
)
if self.model_cpu_offload_seq is None:
raise ValueError(
"Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
@@ -1056,6 +1122,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
self.remove_all_hooks()
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
)
torch_device = torch.device(device)
device_index = torch_device.index
@@ -1090,6 +1162,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
offload_buffers = len(model._parameters) > 0
cpu_offload(model, device, offload_buffers=offload_buffers)
def reset_device_map(self):
r"""
Resets the device maps (if any) to None.
"""
if self.hf_device_map is None:
return
else:
self.remove_all_hooks()
for name, component in self.components.items():
if isinstance(component, torch.nn.Module):
component.to("cpu")
self.hf_device_map = None
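# Tying the guards and reset_device_map() together, a hedged usage sketch
# (model id and call order are illustrative):
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, device_map="balanced"
)

# pipe.to("cuda") or pipe.enable_model_cpu_offload() would raise here because
# the pipeline is device-mapped; dropping the map first makes them legal again.
pipe.reset_device_map()          # removes hooks, moves components back to CPU
pipe.enable_model_cpu_offload()  # or: pipe.to("cuda")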
@classmethod
@validate_hf_hub_args
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:

View File

@@ -255,6 +255,20 @@ def require_torch_accelerator(test_case):
)
def require_torch_multi_gpu(test_case):
"""
Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
multiple GPUs. To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests
-k "multi_gpu"
"""
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
import torch
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
def require_torch_accelerator_with_fp16(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP16 data type."""
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
@@ -343,6 +357,18 @@ def require_peft_version_greater(peft_version):
return decorator
def require_accelerate_version_greater(accelerate_version):
def decorator(test_case):
correct_accelerate_version = is_accelerate_available() and version.parse(
version.parse(importlib.metadata.version("accelerate")).base_version
) > version.parse(accelerate_version)
return unittest.skipUnless(
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
)(test_case)
return decorator
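# A hypothetical test showing how the two new decorators above might be combined
# to gate a device-mapping test (the test name and version bound are made up):
@require_torch_multi_gpu
@require_accelerate_version_greater("0.27.0")
def test_balanced_device_map_smoke(self):
    # Runs only when more than one GPU is visible and accelerate is new enough.
    ...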
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend