mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Use torch.device instead of current device index for BnB quantizer (#10069)
* update
* apply review suggestion
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -176,6 +176,8 @@ def load_model_dict_into_meta(
|
||||
hf_quantizer=None,
|
||||
keep_in_fp32_modules=None,
|
||||
) -> List[str]:
|
||||
+ if device is not None and not isinstance(device, (str, torch.device)):
+ raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
|
||||
if hf_quantizer is None:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
|
||||
@@ -836,7 +836,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
param_device = "cpu"
|
||||
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
|
||||
elif is_quant_method_bnb:
|
||||
- param_device = torch.cuda.current_device()
+ param_device = torch.device(torch.cuda.current_device())
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user