mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -134,19 +134,6 @@ def _fetch_remapped_cls_from_config(config, old_class):
|
||||
return old_class
|
||||
|
||||
|
||||
def _check_archive_and_maybe_raise_error(checkpoint_file, format_list):
|
||||
"""
|
||||
Check format of the archive
|
||||
"""
|
||||
with safetensors.safe_open(checkpoint_file, framework="pt") as f:
|
||||
metadata = f.metadata()
|
||||
if metadata is not None and metadata.get("format") not in format_list:
|
||||
raise OSError(
|
||||
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
|
||||
"you save your model with the `save_pretrained` method."
|
||||
)
|
||||
|
||||
|
||||
def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
|
||||
"""
|
||||
Find the device of param_name from the device_map.
|
||||
@@ -183,7 +170,6 @@ def load_state_dict(
|
||||
# tensors are loaded on cpu
|
||||
with dduf_entries[checkpoint_file].as_mmap() as mm:
|
||||
return safetensors.torch.load(mm)
|
||||
_check_archive_and_maybe_raise_error(checkpoint_file, format_list=["pt", "flax"])
|
||||
if disable_mmap:
|
||||
return safetensors.torch.load(open(checkpoint_file, "rb").read())
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user