[Core] separate the loading utilities in modeling similar to pipelines. (#7943)
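The change is a pure move: four loading helpers leave modeling_utils.py for a new src/diffusers/models/model_loading_utils.py, mirroring how the pipeline loading utilities are factored out. Because modeling_utils.py re-imports the helpers (see the second hunk of that file below), existing imports keep working, while new code can target the dedicated module directly. A minimal sketch of the import paths after this commit:

    # Both of these resolve to the same function after this change.
    from diffusers.models.model_loading_utils import load_state_dict, load_model_dict_into_meta
    from diffusers.models.modeling_utils import load_state_dict  # still works via the re-import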
src/diffusers/models/model_loading_utils.py (new file, +149 lines)
@@ -0,0 +1,149 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import os
from collections import OrderedDict
from typing import List, Optional, Union

import safetensors
import torch

from ..utils import (
    SAFETENSORS_FILE_EXTENSION,
    is_accelerate_available,
    is_torch_version,
    logging,
)


logger = logging.get_logger(__name__)


if is_accelerate_available():
    from accelerate import infer_auto_device_map
    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device


# Adapted from `transformers` (see modeling_utils.py)
def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
    if isinstance(device_map, str):
        no_split_modules = model._get_no_split_modules(device_map)
        device_map_kwargs = {"no_split_module_classes": no_split_modules}

        if device_map != "sequential":
            max_memory = get_balanced_memory(
                model,
                dtype=torch_dtype,
                low_zero=(device_map == "balanced_low_0"),
                max_memory=max_memory,
                **device_map_kwargs,
            )
        else:
            max_memory = get_max_memory(max_memory)

        device_map_kwargs["max_memory"] = max_memory
        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)

    return device_map
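A sketch of what the helper does with a string device_map (hypothetical, not part of the diff; it assumes accelerate is installed, at least one CUDA device is visible, and uses a small made-up UNet2DConditionModel config):

    import torch
    from diffusers import UNet2DConditionModel
    from diffusers.models.model_loading_utils import _determine_device_map

    # Tiny hypothetical config so that placement is cheap to compute.
    model = UNet2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        layers_per_block=1,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=32,
    )
    # "balanced" takes the get_balanced_memory branch; "sequential" would call
    # get_max_memory instead. The result maps module names to devices.
    device_map = _determine_device_map(model, "balanced", None, torch.float16)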
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
    """
    Reads a checkpoint file, returning properly formatted errors if they arise.
    """
    try:
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
        else:
            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
            return torch.load(
                checkpoint_file,
                map_location="cpu",
                **weights_only_kwarg,
            )
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                if f.read().startswith("version"):
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(f"Unable to load weights from checkpoint file '{checkpoint_file}'.")
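Usage is straightforward; dispatch is purely on the file extension (the path below is hypothetical):

    from diffusers.models.model_loading_utils import load_state_dict

    # .safetensors goes through safetensors.torch.load_file on CPU; anything
    # else falls back to torch.load(map_location="cpu"), with weights_only=True
    # on torch >= 1.13.
    state_dict = load_state_dict("unet/diffusion_pytorch_model.safetensors")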
def load_model_dict_into_meta(
    model,
    state_dict: OrderedDict,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    model_name_or_path: Optional[str] = None,
) -> List[str]:
    device = device or torch.device("cpu")
    dtype = dtype or torch.float32

    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())

    unexpected_keys = []
    empty_state_dict = model.state_dict()
    for param_name, param in state_dict.items():
        if param_name not in empty_state_dict:
            unexpected_keys.append(param_name)
            continue

        if empty_state_dict[param_name].shape != param.shape:
            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
            raise ValueError(
                f"Cannot load {model_name_or_path_str}because {param_name} expected shape "
                f"{empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite "
                "randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and "
                "`ignore_mismatched_sizes=True`. For more information, see also: "
                "https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
            )

        if accepts_dtype:
            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
        else:
            set_module_tensor_to_device(model, param_name, device, value=param)
    return unexpected_keys
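The intended flow is to instantiate the model with empty (meta-device) weights and then materialize parameters one by one. A sketch under the assumption that accelerate is installed and a checkpoint file exists locally (the repo id and path are hypothetical):

    import torch
    from accelerate import init_empty_weights
    from diffusers import UNet2DConditionModel
    from diffusers.models.model_loading_utils import load_model_dict_into_meta, load_state_dict

    config = UNet2DConditionModel.load_config("org/model", subfolder="unet")  # hypothetical repo id
    with init_empty_weights():
        model = UNet2DConditionModel.from_config(config)  # parameters on the meta device, no RAM allocated

    state_dict = load_state_dict("unet/diffusion_pytorch_model.safetensors")
    unexpected = load_model_dict_into_meta(model, state_dict, device="cpu", dtype=torch.float16)
    print("checkpoint keys not present in the model:", unexpected)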
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
    # Convert old format to new format if needed from a PyTorch state_dict
    # copy state_dict so _load_from_state_dict can modify it
    state_dict = state_dict.copy()
    error_msgs = []

    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
    # so we need to apply the function recursively.
    def load(module: torch.nn.Module, prefix: str = ""):
        args = (state_dict, prefix, {}, True, [], [], error_msgs)
        module._load_from_state_dict(*args)

        for name, child in module._modules.items():
            if child is not None:
                load(child, prefix + name + ".")

    load(model_to_load)

    return error_msgs
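Note that the function only collects messages; callers decide how to surface them, along these lines (illustrative, not code from the diff):

    errors = _load_state_dict_into_model(model, state_dict)
    if errors:
        raise RuntimeError("Error(s) in loading state_dict:\n\t" + "\n\t".join(errors))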

src/diffusers/models/modeling_utils.py
@@ -33,7 +33,6 @@ from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
-    SAFETENSORS_FILE_EXTENSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     _add_variant,
@@ -44,6 +43,12 @@ from ..utils import (
     logging,
 )
 from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
+from .model_loading_utils import (
+    _determine_device_map,
+    _load_state_dict_into_model,
+    load_model_dict_into_meta,
+    load_state_dict,
+)


 logger = logging.get_logger(__name__)
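For end users nothing changes: from_pretrained simply resolves these helpers from the new module, as in this quick check (any model repo with a unet subfolder works; this one is just an example):

    from diffusers import UNet2DConditionModel

    # Internally this now routes through model_loading_utils.load_state_dict
    # and load_model_dict_into_meta / _load_state_dict_into_model.
    unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="unet")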
@@ -57,9 +62,6 @@ else:

 if is_accelerate_available():
     import accelerate
-    from accelerate import infer_auto_device_map
-    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
-    from accelerate.utils.versions import is_torch_version


 def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
@@ -100,117 +102,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     return first_tuple[1].dtype


-# Adapted from `transformers` (see modeling_utils.py)
-def _determine_device_map(model: "ModelMixin", device_map, max_memory, torch_dtype):
-    ...
-
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
-    ...
-
-def load_model_dict_into_meta(
-    model,
-    state_dict: OrderedDict,
-    device: Optional[Union[str, torch.device]] = None,
-    dtype: Optional[Union[str, torch.dtype]] = None,
-    model_name_or_path: Optional[str] = None,
-) -> List[str]:
-    ...
-
-def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
-    ...

(The 111 removed lines are the four helpers deleted here verbatim; they are identical to the definitions added in model_loading_utils.py above, except that `_determine_device_map` previously annotated its first argument as `model: "ModelMixin"` rather than `model: torch.nn.Module`.)

 class ModelMixin(torch.nn.Module, PushToHubMixin):
     r"""
     Base class for all models.