1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Use torch.device instead of current device index for BnB quantizer (#10069)

* update

* apply review suggestion

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Aryan
2024-12-05 08:05:24 +05:30
committed by GitHub
parent 0d11ab26c4
commit 98d0cd5778
2 changed files with 3 additions and 1 deletions

View File

@@ -176,6 +176,8 @@ def load_model_dict_into_meta(
hf_quantizer=None,
keep_in_fp32_modules=None,
) -> List[str]:
if device is not None and not isinstance(device, (str, torch.device)):
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
if hf_quantizer is None:
device = device or torch.device("cpu")
dtype = dtype or torch.float32

View File

@@ -836,7 +836,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
param_device = "cpu"
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
elif is_quant_method_bnb:
param_device = torch.cuda.current_device()
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant)
model._convert_deprecated_attention_blocks(state_dict)