diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 55ce0cf79f..a6b00e8d11 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1438,8 +1438,17 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if len(resolved_model_file) > 1:
             resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
 
+        map_location = "cpu"
+        if (
+            device_map is not None
+            and hf_quantizer is not None
+            and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+            and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
+        ):
+            map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
+
         for shard_file in resolved_model_file:
-            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+            state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, map_location=map_location)
 
             def _find_mismatched_keys(
                 state_dict,