mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
support hf_quantizer in cache warmup. (#12043)
* support hf_quantizer in cache warmup.
* reviewer feedback
* up
* up
@@ -17,7 +17,6 @@
 import functools
 import importlib
 import inspect
-import math
 import os
 from array import array
 from collections import OrderedDict, defaultdict
@@ -717,27 +716,33 @@ def _expand_device_map(device_map, param_names):
 
 
 # Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
-def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+def _caching_allocator_warmup(
+    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
     """
     This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
     device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
     which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
     very large margin.
     """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
     # Remove disk and cpu devices, and cast to proper torch.device
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    parameter_count = defaultdict(lambda: 0)
+    total_byte_count = defaultdict(lambda: 0)
     for param_name, device in accelerator_device_map.items():
         try:
             param = model.get_parameter(param_name)
         except AttributeError:
             param = model.get_buffer(param_name)
-        parameter_count[device] += math.prod(param.shape)
+        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        # TODO: account for TP when needed.
+        total_byte_count[device] += param_byte_count
 
     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, param_count in parameter_count.items():
-        _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
+    for device, byte_count in total_byte_count.items():
+        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
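For intuition on what the hunk above does at load time: the warmup sums the byte footprint of the parameters assigned to each accelerator device and issues one oversized torch.empty per device, so the caching allocator grabs a large block up front instead of growing through many small mallocs while weights stream in. Below is a minimal standalone sketch of the same accounting, not the diffusers implementation itself; it assumes a CUDA device is available and uses a toy nn.Linear in place of a real model, with factor = 2 (the unquantized default).

import torch
from collections import defaultdict

model = torch.nn.Linear(1024, 1024, dtype=torch.float16)  # toy stand-in for a diffusers model
factor = 2  # default used when hf_quantizer is None
device = torch.device("cuda:0")

# Sum the byte footprint per device, as the warmup does.
total_byte_count = defaultdict(int)
for _, param in model.named_parameters():
    total_byte_count[device] += param.numel() * param.element_size()

before = torch.cuda.memory_reserved(device)
for dev, byte_count in total_byte_count.items():
    # One large throwaway allocation primes the caching allocator for that device.
    _ = torch.empty(byte_count // factor, dtype=torch.float16, device=dev, requires_grad=False)
after = torch.cuda.memory_reserved(device)
print(f"reserved before warmup: {before} bytes, after: {after} bytes")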
@@ -1532,10 +1532,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # tensors using their expected shape and not performing any initialization of the memory (empty data).
         # When the actual device allocations happen, the allocator already has a pool of unused device memory
         # that it can re-use for faster loading of the model.
-        # TODO: add support for warmup with hf_quantizer
-        if device_map is not None and hf_quantizer is None:
+        if device_map is not None:
             expanded_device_map = _expand_device_map(device_map, expected_keys)
-            _caching_allocator_warmup(model, expanded_device_map, dtype)
+            _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)
 
         offload_index = {} if device_map is not None and "disk" in device_map.values() else None
         state_dict_folder, state_dict_index = None, None
@@ -209,6 +209,17 @@ class DiffusersQuantizer(ABC):
 
         return model
 
+    def get_cuda_warm_up_factor(self):
+        """
+        The factor to be used in `caching_allocator_warmup` to get the number of bytes to pre-allocate to warm up cuda.
+        A factor of 2 means we allocate all bytes in the empty model (since we allocate in fp16), a factor of 4 means
+        we allocate half the memory of the weights residing in the empty model, etc...
+        """
+        # By default we return 4, i.e. half the model size (this corresponds to the case where the model is not
+        # really pre-processed, i.e. we do not have the info that weights are going to be 8 bits before actual
+        # weight loading)
+        return 4
+
     def _dequantize(self, model):
         raise NotImplementedError(
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
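As a back-of-the-envelope check of the factor semantics in the docstring above (hypothetical numbers, not taken from the PR): the warmup allocates byte_count // factor elements of the load dtype, so with fp16 a factor of 2 reserves the full weight footprint, 4 reserves half, and so on.

numel = 1_000_000_000                      # hypothetical 1B-parameter model
byte_count = numel * 2                     # fp16 footprint: ~2 GB
for factor in (2, 4, 8):
    reserved = (byte_count // factor) * 2  # elements allocated * 2 bytes per fp16 element
    print(f"factor {factor}: ~{reserved / 1e9:.1f} GB reserved")
# factor 2: ~2.0 GB (full footprint), factor 4: ~1.0 GB (base-class default), factor 8: ~0.5 GB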
@@ -19,6 +19,7 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
 
 import importlib
 import types
+from fnmatch import fnmatch
 from typing import TYPE_CHECKING, Any, Dict, List, Union
 
 from packaging import version
@@ -278,6 +279,31 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
             module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
             quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
+    def get_cuda_warm_up_factor(self):
+        """
+        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
+        - A factor of 2 means we pre-allocate the full memory footprint of the model.
+        - A factor of 4 means we pre-allocate half of that, and so on
+
+        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
+        the correct size for quantized weights (like int4 or int8). That's because TorchAO internally represents
+        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
+        torch_dtype, not the actual bit-width of the quantized data.
+
+        To correct for this:
+        - Use a division factor of 8 for int4 weights
+        - Use a division factor of 4 for int8 weights
+        """
+        # Original mapping for non-AOBaseConfig types
+        # For the uint types, this is a best guess. Once these types become more used
+        # we can look into their nuances.
+        map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
+        quant_type = self.quantization_config.quant_type
+        for pattern, target_dtype in map_to_target_dtype.items():
+            if fnmatch(quant_type, pattern):
+                return target_dtype
+        raise ValueError(f"Unsupported quant_type: {quant_type!r}")
+
     def _process_model_before_weight_loading(
         self,
         model: "ModelMixin",
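To see how the fnmatch dispatch above resolves a quantization type to a warm-up factor, here is a self-contained repro of just that lookup. The quant_type strings are illustrative examples of torchao-style names ("int4_weight_only", "float8_weight_only"), not an exhaustive or authoritative list.

from fnmatch import fnmatch

# Same glob-pattern table as in the TorchAoHfQuantizer override above.
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}

def warm_up_factor(quant_type: str) -> int:
    # First matching pattern wins; unknown types raise, mirroring the method above.
    for pattern, factor in map_to_target_dtype.items():
        if fnmatch(quant_type, pattern):
            return factor
    raise ValueError(f"Unsupported quant_type: {quant_type!r}")

print(warm_up_factor("int4_weight_only"))    # 8
print(warm_up_factor("float8_weight_only"))  # 4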