diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index ebc7d79aeb..65d008a8e8 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -18,7 +18,7 @@ import importlib
 import inspect
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -38,6 +38,7 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerator_device,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -304,6 +305,51 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index
 
 
+# Taken from
+# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
+def _expand_device_map(device_map, param_names):
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
+# We don't incorporate the `tp_plan` stuff as we don't support it yet.
+def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> None:
+    # Remove disk, cpu and meta devices, and cast to proper torch.device
+    accelerator_device_map = {
+        param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
+    }
+    if not len(accelerator_device_map):
+        return
+
+    total_byte_count = defaultdict(lambda: 0)
+    for param_name, device in accelerator_device_map.items():
+        param = model.get_parameter_or_buffer(param_name)
+        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        total_byte_count[device] += param_byte_count
+
+    # This primes the caching allocator so we avoid a separate cudaMalloc for every parameter afterwards
+    for device, byte_count in total_byte_count.items():
+        if device.type == "cuda":
+            index = device.index if device.index is not None else torch.cuda.current_device()
+            device_memory = torch.cuda.mem_get_info(index)[0]
+            # Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
+            # than that amount might sometimes lead to unnecessary cuda OOM, if the last parameter to be loaded on the device is large,
+            # and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
+            # the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
+            # to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
+            # Note that we use an absolute value instead of device proportion here, as an 8GiB device could still allocate too much
+            # if using e.g. 90% of device size, while a 140GiB device would allocate too little
+            byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
+        # Allocate memory
+        _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 2a22bc09ad..51f3497836 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -63,7 +63,9 @@ from ..utils.hub_utils import (
     populate_model_card,
 )
 from .model_loading_utils import (
+    _caching_allocator_warmup,
     _determine_device_map,
+    _expand_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
     _load_state_dict_into_model,
@@ -1374,6 +1376,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:
             return super().float(*args)
 
+    # Taken from `transformers`.
+    # https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5351C5-L5365C81
+    def get_parameter_or_buffer(self, target: str):
+        """
+        Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
+        `get_parameter()` and `get_buffer()` in a single handy function. Note that it only works if `target` is a
+        leaf of the model.
+        """
+        try:
+            return self.get_parameter(target)
+        except AttributeError:
+            pass
+        try:
+            return self.get_buffer(target)
+        except AttributeError:
+            pass
+        raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
+
     @classmethod
     def _load_pretrained_model(
         cls,
@@ -1410,6 +1430,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         assign_to_params_buffers = None
         error_msgs = []
 
+        # Optionally warm up the accelerator's caching allocator, which makes loading the weights much faster
+        if device_map is not None:
+            expanded_device_map = _expand_device_map(device_map, expected_keys)
+            _caching_allocator_warmup(model, expanded_device_map, factor=2 if hf_quantizer is None else 4)
+
         # Deal with offload
         if device_map is not None and "disk" in device_map.values():
             if offload_folder is None:
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index ed89955ba5..d5934066d6 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -129,6 +129,7 @@ from .state_dict_utils import (
     convert_unet_state_dict_to_peft,
     state_dict_all_zero,
 )
+from .testing_utils import is_accelerator_device
 from .typing_utils import _get_detailed_type, _is_valid_type
 
 
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 4ba6f7c25e..0a22ecaa7c 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1289,6 +1289,18 @@ if is_torch_available():
     update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
 
 
+if is_torch_available():
+    # Taken from
+    # https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5864C1-L5871C64
+    def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
+        """Check if the device is an accelerator. We need this function, as device_map can be "disk" as well, which
+        is not a proper `torch.device`.
+        """
+        if device == "disk":
+            return False
+        else:
+            return torch.device(device).type not in ["meta", "cpu"]
+
+
 # Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090
 # Type definition of key used in `Expectations` class.
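For context on the new `_caching_allocator_warmup` helper, here is a minimal standalone sketch (not part of the diff; the helper name and shapes are made up) of why a single large `torch.empty` allocation speeds up the subsequent weight loading: PyTorch's CUDA caching allocator keeps freed blocks in its pool, so the later per-parameter allocations are served from the cached reservation instead of triggering many separate device allocations.

import math

import torch


def materialize_params(shapes, device, warmup=True, factor=2):
    # Hypothetical helper for illustration only; it mimics the warmup-then-load pattern.
    if warmup:
        total_bytes = sum(2 * math.prod(s) for s in shapes)  # fp16 -> 2 bytes per element
        # Reserve a large block up front. The tensor is discarded immediately, but the caching
        # allocator keeps the freed block in its pool for the allocations performed below.
        _ = torch.empty(total_bytes // factor, dtype=torch.float16, device=device, requires_grad=False)
    # The real loading path allocates one tensor per parameter; with the warmup, these come
    # mostly out of the allocator's cache rather than fresh device allocations.
    return [torch.empty(s, dtype=torch.float16, device=device) for s in shapes]


if torch.cuda.is_available():
    tensors = materialize_params([(4096, 4096)] * 16, "cuda:0")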
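Similarly, a small usage sketch of `_expand_device_map` as it is invoked from `_load_pretrained_model`. The parameter names below are invented, but the expansion rule (a key matches itself and its children, and the empty key `""` matches everything) is the one implemented in the hunk above; the import assumes this PR is applied.

from diffusers.models.model_loading_utils import _expand_device_map

device_map = {"": "cuda:0"}  # single-device map, e.g. from `device_map="cuda:0"`
expected_keys = ["conv_in.weight", "conv_in.bias", "mid_block.attentions.0.to_q.weight"]

expanded = _expand_device_map(device_map, expected_keys)
# {"conv_in.weight": "cuda:0", "conv_in.bias": "cuda:0", "mid_block.attentions.0.to_q.weight": "cuda:0"}

# With a map split across devices, each parameter picks up the device of its matching prefix:
sharded = _expand_device_map({"conv_in": "cuda:0", "mid_block": "cuda:1"}, expected_keys)
# {"conv_in.weight": "cuda:0", "conv_in.bias": "cuda:0", "mid_block.attentions.0.to_q.weight": "cuda:1"}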
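Finally, a quick sketch of the filtering done by `is_accelerator_device`: only entries of the expanded map that live on a real accelerator take part in the warmup, while "disk" offload targets and cpu/meta placements are skipped. The import path assumes the re-export added to `src/diffusers/utils/__init__.py` in this diff.

import torch

from diffusers.utils import is_accelerator_device

assert is_accelerator_device("cuda:0")                 # accelerator -> participates in the warmup
assert is_accelerator_device(torch.device("cuda", 1))  # torch.device inputs are accepted as well
assert not is_accelerator_device("cpu")                # skipped
assert not is_accelerator_device("meta")               # skipped
assert not is_accelerator_device("disk")               # offload target, not a valid torch.device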