Mirror of https://github.com/huggingface/diffusers.git
fix: caching allocator behaviour for quantization. (#12172)
* fix: caching allocator behaviour for quantization.

* up

* Update src/diffusers/models/model_loading_utils.py

Co-authored-by: Aryan <aryan@huggingface.co>

---------

Co-authored-by: Aryan <aryan@huggingface.co>
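What the diff below changes, in short: the old warmup summed per-parameter byte counts (`numel() * element_size()`) and passed that byte total to `torch.empty`, which interprets its size argument as an element count in `dtype`, not as bytes. With a quantizer in play, parameters are typically stored in a narrower dtype than the warmup `dtype`, so the two unit systems disagree and the reserved block no longer matches the intended size. The new code counts elements per device, divides by the quantizer-aware `factor`, and lets `torch.empty` apply the element size of `dtype` itself; it also returns early when no accelerator device is mapped and raises a clear error for names that are neither parameters nor buffers.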
@@ -726,23 +726,29 @@ def _caching_allocator_warmup(
     very large margin.
     """
     factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
-    # Remove disk and cpu devices, and cast to proper torch.device
+    # Keep only accelerator devices
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    total_byte_count = defaultdict(lambda: 0)
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
     for param_name, device in accelerator_device_map.items():
         try:
-            param = model.get_parameter(param_name)
+            p = model.get_parameter(param_name)
         except AttributeError:
-            param = model.get_buffer(param_name)
-        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = param.numel() * param.element_size()
+            try:
+                p = model.get_buffer(param_name)
+            except AttributeError:
+                raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
         # TODO: account for TP when needed.
-        total_byte_count[device] += param_byte_count
+        elements_per_device[device] += p.numel()

     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in total_byte_count.items():
-        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
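A minimal, self-contained sketch of the warmup pattern the new code follows (the toy module, device map, and `factor` value here are hypothetical stand-ins for illustration, not the diffusers API; run it on a CUDA machine, or substitute "cpu" just to watch the bookkeeping):

from collections import defaultdict

import torch
import torch.nn as nn

# Hypothetical stand-ins: a tiny module and a pre-expanded parameter -> device map.
model = nn.Linear(4, 4, dtype=torch.bfloat16)
expanded_device_map = {"weight": "cuda:0", "bias": "cuda:0"}
dtype = torch.bfloat16
factor = 2  # no quantizer: warm up with half the element count

# Count elements (not bytes) per accelerator device, as the fixed code does.
elements_per_device = defaultdict(int)
for name, device in expanded_device_map.items():
    if str(device) in ["cpu", "disk"]:
        continue
    elements_per_device[torch.device(device)] += model.get_parameter(name).numel()

# One large torch.empty per device primes the CUDA caching allocator, so later
# per-tensor copies during weight loading reuse cached blocks instead of
# triggering a fresh cudaMalloc each time.
for device, elem_count in elements_per_device.items():
    _ = torch.empty(max(1, elem_count // factor), dtype=dtype, device=device, requires_grad=False)

The `factor` division mirrors the diff: with no quantizer, half the element count is allocated up front in `dtype`; a quantizer supplies its own `get_cuda_warm_up_factor()`, presumably so the reserved block tracks the quantized storage footprint rather than the full-precision one.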