mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
quick fix
This commit is contained in:
@@ -1185,8 +1185,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
state_dict = None
|
||||
if not is_sharded:
|
||||
map_location = "cpu"
|
||||
if (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
# Time to load the checkpoint
|
||||
state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries)
|
||||
state_dict = load_state_dict(resolved_model_file[0], disable_mmap=disable_mmap, dduf_entries=dduf_entries, map_location=map_location)
|
||||
# We only fix it for non sharded checkpoints as we don't need it yet for sharded one.
|
||||
model._fix_state_dict_keys_on_load(state_dict)
|
||||
|
||||
@@ -1443,10 +1451,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and hf_quantizer.quantfization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
|
||||
for shard_file in resolved_model_file:
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, map_location=map_location)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user