From 98d0cd5778afef0f8361908ed613ebcc285c1581 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 5 Dec 2024 08:05:24 +0530 Subject: [PATCH] Use torch.device instead of current device index for BnB quantizer (#10069) * update * apply review suggestion --------- Co-authored-by: Sayak Paul --- src/diffusers/models/model_loading_utils.py | 2 ++ src/diffusers/models/modeling_utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 932a945711..751117f8f2 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -176,6 +176,8 @@ def load_model_dict_into_meta( hf_quantizer=None, keep_in_fp32_modules=None, ) -> List[str]: + if device is not None and not isinstance(device, (str, torch.device)): + raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.") if hf_quantizer is None: device = device or torch.device("cpu") dtype = dtype or torch.float32 diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 76f6c5f630..7b2022798d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -836,7 +836,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): param_device = "cpu" # TODO (sayakpaul, SunMarc): remove this after model loading refactor elif is_quant_method_bnb: - param_device = torch.cuda.current_device() + param_device = torch.device(torch.cuda.current_device()) state_dict = load_state_dict(model_file, variant=variant) model._convert_deprecated_attention_blocks(state_dict)