diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
new file mode 100644
index 0000000000..635cd0ba57
--- /dev/null
+++ b/src/diffusers/models/model_loading_utils.py
@@ -0,0 +1,149 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import os
+from collections import OrderedDict
+from typing import List, Optional, Union
+
+import safetensors
+import torch
+
+from ..utils import (
+    SAFETENSORS_FILE_EXTENSION,
+    is_accelerate_available,
+    is_torch_version,
+    logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_accelerate_available():
+    from accelerate import infer_auto_device_map
+    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
+
+
+# Adapted from `transformers` (see modeling_utils.py)
+def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+    if isinstance(device_map, str):
+        no_split_modules = model._get_no_split_modules(device_map)
+        device_map_kwargs = {"no_split_module_classes": no_split_modules}
+
+        if device_map != "sequential":
+            max_memory = get_balanced_memory(
+                model,
+                dtype=torch_dtype,
+                low_zero=(device_map == "balanced_low_0"),
+                max_memory=max_memory,
+                **device_map_kwargs,
+            )
+        else:
+            max_memory = get_max_memory(max_memory)
+
+        device_map_kwargs["max_memory"] = max_memory
+        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+
+    return device_map
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+    """
+    Reads a checkpoint file, returning properly formatted errors if they arise.
+    """
+    try:
+        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
+        if file_extension == SAFETENSORS_FILE_EXTENSION:
+            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+        else:
+            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
+            return torch.load(
+                checkpoint_file,
+                map_location="cpu",
+                **weights_only_kwarg,
+            )
+    except Exception as e:
+        try:
+            with open(checkpoint_file) as f:
+                if f.read().startswith("version"):
+                    raise OSError(
+                        "You seem to have cloned a repository without having git-lfs installed. Please install "
+                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                        "you cloned."
+                    )
+                else:
+                    raise ValueError(
+                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                        "model. Make sure you have saved the model properly."
+                    ) from e
+        except (UnicodeDecodeError, ValueError):
+            raise OSError(
+                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
+            )
+
+
+def load_model_dict_into_meta(
+    model,
+    state_dict: OrderedDict,
+    device: Optional[Union[str, torch.device]] = None,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    model_name_or_path: Optional[str] = None,
+) -> List[str]:
+    device = device or torch.device("cpu")
+    dtype = dtype or torch.float32
+
+    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+
+    unexpected_keys = []
+    empty_state_dict = model.state_dict()
+    for param_name, param in state_dict.items():
+        if param_name not in empty_state_dict:
+            unexpected_keys.append(param_name)
+            continue
+
+        if empty_state_dict[param_name].shape != param.shape:
+            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+            raise ValueError(
+                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+            )
+
+        if accepts_dtype:
+            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, param_name, device, value=param)
+    return unexpected_keys
+
+
+def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
+    # Convert old format to new format if needed from a PyTorch state_dict
+    # copy state_dict so _load_from_state_dict can modify it
+    state_dict = state_dict.copy()
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module: torch.nn.Module, prefix: str = ""):
+        args = (state_dict, prefix, {}, True, [], [], error_msgs)
+        module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + ".")
+
+    load(model_to_load)
+
+    return error_msgs
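
A minimal sketch of how the relocated helpers compose, assuming a checkpoint whose keys match the model: `load_state_dict` dispatches on the file extension, and `_load_state_dict_into_model` copies the resulting tensors into an instantiated module. The checkpoint path is hypothetical, and the default `UNet2DModel` config stands in for a real model:

    import torch

    from diffusers import UNet2DModel
    from diffusers.models.model_loading_utils import _load_state_dict_into_model, load_state_dict

    model = UNet2DModel()  # randomly initialized, default config

    # Round-trip: save the weights, then reload them through the new helpers.
    torch.save(model.state_dict(), "checkpoint.bin")  # hypothetical local path

    state_dict = load_state_dict("checkpoint.bin")  # ".bin" extension -> torch.load branch
    error_msgs = _load_state_dict_into_model(model, state_dict)
    assert not error_msgs  # keys and shapes match, so no errors are collected
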
" + ) + + +def load_model_dict_into_meta( + model, + state_dict: OrderedDict, + device: Optional[Union[str, torch.device]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + model_name_or_path: Optional[str] = None, +) -> List[str]: + device = device or torch.device("cpu") + dtype = dtype or torch.float32 + + accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) + + unexpected_keys = [] + empty_state_dict = model.state_dict() + for param_name, param in state_dict.items(): + if param_name not in empty_state_dict: + unexpected_keys.append(param_name) + continue + + if empty_state_dict[param_name].shape != param.shape: + model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" + raise ValueError( + f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." + ) + + if accepts_dtype: + set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) + else: + set_module_tensor_to_device(model, param_name, device, value=param) + return unexpected_keys + + +def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix: str = ""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + + return error_msgs diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index e8ba1ed65d..2ed5655c84 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -33,7 +33,6 @@ from .. 
@@ -100,117 +102,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
         return first_tuple[1].dtype
 
 
-# Adapted from `transformers` (see modeling_utils.py)
-def _determine_device_map(model: "ModelMixin", device_map, max_memory, torch_dtype):
-    if isinstance(device_map, str):
-        no_split_modules = model._get_no_split_modules(device_map)
-        device_map_kwargs = {"no_split_module_classes": no_split_modules}
-
-        if device_map != "sequential":
-            max_memory = get_balanced_memory(
-                model,
-                dtype=torch_dtype,
-                low_zero=(device_map == "balanced_low_0"),
-                max_memory=max_memory,
-                **device_map_kwargs,
-            )
-        else:
-            max_memory = get_max_memory(max_memory)
-
-        device_map_kwargs["max_memory"] = max_memory
-        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
-
-    return device_map
-
-
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
-    """
-    Reads a checkpoint file, returning properly formatted errors if they arise.
-    """
-    try:
-        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
-        if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
-        else:
-            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
-            return torch.load(
-                checkpoint_file,
-                map_location="cpu",
-                **weights_only_kwarg,
-            )
-    except Exception as e:
-        try:
-            with open(checkpoint_file) as f:
-                if f.read().startswith("version"):
-                    raise OSError(
-                        "You seem to have cloned a repository without having git-lfs installed. Please install "
-                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
-                        "you cloned."
-                    )
-                else:
-                    raise ValueError(
-                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
-                        "model. Make sure you have saved the model properly."
-                    ) from e
-        except (UnicodeDecodeError, ValueError):
-            raise OSError(
-                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
-            )
" - ) - - -def load_model_dict_into_meta( - model, - state_dict: OrderedDict, - device: Optional[Union[str, torch.device]] = None, - dtype: Optional[Union[str, torch.dtype]] = None, - model_name_or_path: Optional[str] = None, -) -> List[str]: - device = device or torch.device("cpu") - dtype = dtype or torch.float32 - - accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) - - unexpected_keys = [] - empty_state_dict = model.state_dict() - for param_name, param in state_dict.items(): - if param_name not in empty_state_dict: - unexpected_keys.append(param_name) - continue - - if empty_state_dict[param_name].shape != param.shape: - model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else "" - raise ValueError( - f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." - ) - - if accepts_dtype: - set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype) - else: - set_module_tensor_to_device(model, param_name, device, value=param) - return unexpected_keys - - -def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]: - # Convert old format to new format if needed from a PyTorch state_dict - # copy state_dict so _load_from_state_dict can modify it - state_dict = state_dict.copy() - error_msgs = [] - - # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants - # so we need to apply the function recursively. - def load(module: torch.nn.Module, prefix: str = ""): - args = (state_dict, prefix, {}, True, [], [], error_msgs) - module._load_from_state_dict(*args) - - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + ".") - - load(model_to_load) - - return error_msgs - - class ModelMixin(torch.nn.Module, PushToHubMixin): r""" Base class for all models.