Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[Lora] Speed up lora loading (#4994)

* speed up lora loading
* Apply suggestions from code review
* up
* up
* Fix more
* Correct more
* Apply suggestions from code review
* up
* Fix more
* Fix more -
* up
* up
Commit 37cb819df5 (parent f64d52dbca), committed via GitHub.
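What the commit enables, as a usage sketch: `load_lora_weights` now forwards a `low_cpu_mem_usage` kwarg (defaulting to `_LOW_CPU_MEM_USAGE_DEFAULT`, i.e. `True` on torch >= 1.9.0), so LoRA layers are built on the meta device and filled straight from the checkpoint. The repo ids below are illustrative, not part of this commit:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# New in this commit: LoRA layers are created on the meta device and filled
# directly from the checkpoint, instead of being randomly initialized first.
pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4", low_cpu_mem_usage=True)
```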
src/diffusers/loaders.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import os
 import re
 import warnings
@@ -27,6 +26,7 @@ import torch
 from huggingface_hub import hf_hub_download, model_info
 from torch import nn
 
+from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -46,7 +46,6 @@ if is_transformers_available():
 if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-    from accelerate.utils import set_module_tensor_to_device
 
 logger = logging.get_logger(__name__)
 
@@ -137,7 +136,6 @@ class PatchedLoraProjection(nn.Module):
         self.w_down = None
 
     def forward(self, input):
-        # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
         if self.lora_scale is None:
             self.lora_scale = 1.0
         if self.lora_linear_layer is None:
@@ -274,6 +272,11 @@ class UNet2DConditionLoadersMixin:
             use_auth_token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
@@ -300,6 +303,7 @@ class UNet2DConditionLoadersMixin:
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
        network_alphas = kwargs.pop("network_alphas", None)
@@ -316,6 +320,15 @@ class UNet2DConditionLoadersMixin:
             "framework": "pytorch",
         }
 
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
@@ -370,6 +383,10 @@ class UNet2DConditionLoadersMixin:
             # correct keys
             state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
 
+            if network_alphas is not None:
+                network_alphas_keys = list(network_alphas.keys())
+                used_network_alphas_keys = set()
+
             lora_grouped_dict = defaultdict(dict)
             mapped_network_alphas = {}
 
@@ -381,13 +398,13 @@ class UNet2DConditionLoadersMixin:
 
                # Create another `mapped_network_alphas` dictionary so that we can properly map them.
                if network_alphas is not None:
-                    network_alphas_ = copy.deepcopy(network_alphas)
-                    for k in network_alphas_:
+                    for k in network_alphas_keys:
                        if k.replace(".alpha", "") in key:
-                            mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
+                            mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
+                            used_network_alphas_keys.add(k)
 
            if not is_network_alphas_none:
-                if len(network_alphas) > 0:
+                if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
                    raise ValueError(
                        f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
                    )
@@ -411,29 +428,38 @@ class UNet2DConditionLoadersMixin:
                    out_features = attn_processor.out_channels
                    kernel_size = attn_processor.kernel_size
 
-                    lora = LoRAConv2dLayer(
-                        in_features=in_features,
-                        out_features=out_features,
-                        rank=rank,
-                        kernel_size=kernel_size,
-                        stride=attn_processor.stride,
-                        padding=attn_processor.padding,
-                        network_alpha=mapped_network_alphas.get(key),
-                    )
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        lora = LoRAConv2dLayer(
+                            in_features=in_features,
+                            out_features=out_features,
+                            rank=rank,
+                            kernel_size=kernel_size,
+                            stride=attn_processor.stride,
+                            padding=attn_processor.padding,
+                            network_alpha=mapped_network_alphas.get(key),
+                        )
                elif isinstance(attn_processor, LoRACompatibleLinear):
-                    lora = LoRALinearLayer(
-                        attn_processor.in_features,
-                        attn_processor.out_features,
-                        rank,
-                        mapped_network_alphas.get(key),
-                    )
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        lora = LoRALinearLayer(
+                            attn_processor.in_features,
+                            attn_processor.out_features,
+                            rank,
+                            mapped_network_alphas.get(key),
+                        )
                else:
                    raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
 
                value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
-                lora.load_state_dict(value_dict)
                lora_layers_list.append((attn_processor, lora))
 
+                if low_cpu_mem_usage:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
+                else:
+                    lora.load_state_dict(value_dict)
        elif is_custom_diffusion:
            attn_processors = {}
            custom_diffusion_grouped_dict = defaultdict(dict)
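The pattern in the hunk above — construct under `init_empty_weights` when `low_cpu_mem_usage` is set, then materialize tensors straight from the checkpoint — is what removes the redundant random initialization. A self-contained sketch of the same idea with a plain `nn.Linear` (sizes and names here are illustrative):

```python
from contextlib import nullcontext

import torch
from torch import nn
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

low_cpu_mem_usage = True

# On the meta device no parameter storage is allocated, so construction is free.
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
    layer = nn.Linear(320, 4)  # e.g. a rank-4 "down" projection

state_dict = {"weight": torch.randn(4, 320), "bias": torch.zeros(4)}

if low_cpu_mem_usage:
    # Replace each meta tensor with the checkpoint tensor, one parameter at a time.
    for name, value in state_dict.items():
        set_module_tensor_to_device(layer, name, "cpu", value=value)
else:
    layer.load_state_dict(state_dict)
```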
@@ -470,13 +496,12 @@ class UNet2DConditionLoadersMixin:
                    f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
                )
 
-        # set correct dtype & device
-        lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
-
        # set lora layers
        for target_module, lora_layer in lora_layers_list:
            target_module.set_lora_layer(lora_layer)
 
+        self.to(dtype=self.dtype, device=self.device)
+
    def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
        is_new_lora_format = all(
            key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
@@ -999,13 +1024,18 @@ class LoraLoaderMixin:
                    recurive = is_sequential_cpu_offload
                    remove_hook_from_module(component, recurse=recurive)
 
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+        self.load_lora_into_unet(
+            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+        )
        self.load_lora_into_text_encoder(
            state_dict,
            network_alphas=network_alphas,
            text_encoder=self.text_encoder,
            lora_scale=self.lora_scale,
+            low_cpu_mem_usage=low_cpu_mem_usage,
        )
 
        # Offload back.
@@ -1065,6 +1095,11 @@ class LoraLoaderMixin:
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -1305,7 +1340,7 @@ class LoraLoaderMixin:
        return new_state_dict
 
    @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.
 
@@ -1318,7 +1353,13 @@ class LoraLoaderMixin:
                See `LoRALinearLayer` for more details.
            unet (`UNet2DConditionModel`):
                The UNet model to load the LoRA layers into.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
        """
+        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
@@ -1343,11 +1384,12 @@ class LoraLoaderMixin:
            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
            warnings.warn(warn_message)
 
        # load loras into unet
-        unet.load_attn_procs(state_dict, network_alphas=network_alphas)
+        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
 
    @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
+    def load_lora_into_text_encoder(
+        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`
 
@@ -1364,7 +1406,13 @@ class LoraLoaderMixin:
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added with the output of the regular
                lora layer.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
        """
+        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
 
        # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
@@ -1447,6 +1495,7 @@ class LoraLoaderMixin:
                network_alphas,
                rank=rank,
                patch_mlp=patch_mlp,
+                low_cpu_mem_usage=low_cpu_mem_usage,
            )
 
            # set correct dtype & device
@@ -1454,12 +1503,23 @@ class LoraLoaderMixin:
                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
                for k, v in text_encoder_lora_state_dict.items()
            }
-            load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
-            if len(load_state_dict_results.unexpected_keys) != 0:
+            if low_cpu_mem_usage:
+                device = next(iter(text_encoder_lora_state_dict.values())).device
+                dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
+                unexpected_keys = load_model_dict_into_meta(
+                    text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
+                )
+            else:
+                load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
+                unexpected_keys = load_state_dict_results.unexpected_keys
+
+            if len(unexpected_keys) != 0:
                raise ValueError(
                    f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                )
 
+            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
    @property
    def lora_scale(self) -> float:
        # property function that returns the lora scale which can be set at run time by the pipeline.
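The `else` branch above keeps the previous behavior: `load_state_dict(..., strict=False)` tolerates missing keys (a LoRA state dict covers only a small subset of the text encoder) but still reports keys that matched nothing, and both branches now funnel into a single `unexpected_keys` check. A small sketch of that `strict=False` contract:

```python
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4))

partial = {
    "0.weight": torch.zeros(4, 4),   # matches -> loaded
    "bogus.weight": torch.zeros(1),  # matches nothing -> reported back
}

# strict=False returns an _IncompatibleKeys result instead of raising.
result = model.load_state_dict(partial, strict=False)
print(result.unexpected_keys)  # ['bogus.weight']
print(result.missing_keys)     # ['0.bias']
```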
@@ -1492,11 +1552,21 @@ class LoraLoaderMixin:
        rank: Union[Dict[str, int], int] = 4,
        dtype=None,
        patch_mlp=False,
+        low_cpu_mem_usage=False,
    ):
        r"""
        Monkey-patches the forward passes of attention modules of the text encoder.
        """
 
+        def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
+            linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
+            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+            with ctx():
+                model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
+
+            lora_parameters.extend(model.lora_linear_layer.parameters())
+            return model
+
        # First, remove any monkey-patch that might have been applied before
        cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
 
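The new `create_patched_linear_lora` closure replaces four copies of the unwrap-and-rewrap logic in the next hunk. The `isinstance(model, PatchedLoraProjection)` check makes patching idempotent: an already patched projection is unwrapped back to its `regular_linear_layer` rather than wrapped twice. A generic sketch of that idea (the `Wrapper` class is an illustrative stand-in, not the diffusers class):

```python
from torch import nn

class Wrapper(nn.Module):
    """Illustrative stand-in for PatchedLoraProjection: wraps a Linear layer."""

    def __init__(self, linear):
        super().__init__()
        self.regular_linear_layer = linear

    def forward(self, x):
        return self.regular_linear_layer(x)  # + the LoRA delta in the real class

def patch(module):
    # Unwrap first so patching twice never nests Wrapper(Wrapper(...)).
    linear = module.regular_linear_layer if isinstance(module, Wrapper) else module
    return Wrapper(linear)

proj = nn.Linear(8, 8)
proj = patch(patch(proj))
assert isinstance(proj.regular_linear_layer, nn.Linear)  # still a single wrapper
```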
@@ -1515,45 +1585,18 @@ class LoraLoaderMixin:
            else:
                current_rank = rank
 
-            q_linear_layer = (
-                attn_module.q_proj.regular_linear_layer
-                if isinstance(attn_module.q_proj, PatchedLoraProjection)
-                else attn_module.q_proj
-            )
-            attn_module.q_proj = PatchedLoraProjection(
-                q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
-            )
-            lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
-
-            k_linear_layer = (
-                attn_module.k_proj.regular_linear_layer
-                if isinstance(attn_module.k_proj, PatchedLoraProjection)
-                else attn_module.k_proj
-            )
-            attn_module.k_proj = PatchedLoraProjection(
-                k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
-            )
-            lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
-
-            v_linear_layer = (
-                attn_module.v_proj.regular_linear_layer
-                if isinstance(attn_module.v_proj, PatchedLoraProjection)
-                else attn_module.v_proj
-            )
-            attn_module.v_proj = PatchedLoraProjection(
-                v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
-            )
-            lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
-
-            out_linear_layer = (
-                attn_module.out_proj.regular_linear_layer
-                if isinstance(attn_module.out_proj, PatchedLoraProjection)
-                else attn_module.out_proj
-            )
-            attn_module.out_proj = PatchedLoraProjection(
-                out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
-            )
-            lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
-
+            attn_module.q_proj = create_patched_linear_lora(
+                attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
+            )
+            attn_module.k_proj = create_patched_linear_lora(
+                attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
+            )
+            attn_module.v_proj = create_patched_linear_lora(
+                attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
+            )
+            attn_module.out_proj = create_patched_linear_lora(
+                attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
+            )
+
        if patch_mlp:
            for name, mlp_module in text_encoder_mlp_modules(text_encoder):
@@ -1563,25 +1606,12 @@ class LoraLoaderMixin:
                    current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
                    current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
 
-                    fc1_linear_layer = (
-                        mlp_module.fc1.regular_linear_layer
-                        if isinstance(mlp_module.fc1, PatchedLoraProjection)
-                        else mlp_module.fc1
-                    )
-                    mlp_module.fc1 = PatchedLoraProjection(
-                        fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
-                    )
-                    lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
-
-                    fc2_linear_layer = (
-                        mlp_module.fc2.regular_linear_layer
-                        if isinstance(mlp_module.fc2, PatchedLoraProjection)
-                        else mlp_module.fc2
-                    )
-                    mlp_module.fc2 = PatchedLoraProjection(
-                        fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
-                    )
-                    lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
+                    mlp_module.fc1 = create_patched_linear_lora(
+                        mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
+                    )
+                    mlp_module.fc2 = create_patched_linear_lora(
+                        mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
+                    )
 
        if is_network_alphas_populated and len(network_alphas) > 0:
            raise ValueError(
@@ -2375,8 +2405,7 @@ class FromOriginalVAEMixin:
        vae = AutoencoderKL(**vae_config)
 
        if is_accelerate_available():
-            for param_name, param in converted_vae_checkpoint.items():
-                set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+            load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
        else:
            vae.load_state_dict(converted_vae_checkpoint)
 
src/diffusers/models/modeling_utils.py
@@ -128,6 +128,31 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
        )
 
 
+def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
+    device = device or torch.device("cpu")
+    dtype = dtype or torch.float32
+
+    unexpected_keys = []
+    empty_state_dict = model.state_dict()
+    for param_name, param in state_dict.items():
+        if param_name not in empty_state_dict:
+            unexpected_keys.append(param_name)
+            continue
+
+        if empty_state_dict[param_name].shape != param.shape:
+            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+            raise ValueError(
+                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+            )
+
+        accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+        if accepts_dtype:
+            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, param_name, device, value=param)
+    return unexpected_keys
+
+
 def _load_state_dict_into_model(model_to_load, state_dict):
    # Convert old format to new format if needed from a PyTorch state_dict
    # copy state_dict so _load_from_state_dict can modify it
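Two details of `load_model_dict_into_meta` are worth noting: it returns `unexpected_keys` to the caller instead of raising on them, and it probes whether the installed `accelerate` version's `set_module_tensor_to_device` accepts a `dtype` kwarg before passing one (older releases did not). The same introspection pattern in isolation — all functions below are illustrative stand-ins:

```python
import inspect

def call_with_optional_dtype(fn, *args, dtype=None, **kwargs):
    # Mirrors the accepts_dtype check above: pass `dtype` only if the
    # installed function actually declares that parameter.
    if "dtype" in inspect.signature(fn).parameters:
        return fn(*args, dtype=dtype, **kwargs)
    return fn(*args, **kwargs)

def old_api(x):              # stand-in for an older signature without `dtype`
    return x

def new_api(x, dtype=None):  # stand-in for a newer signature with `dtype`
    return (x, dtype)

print(call_with_optional_dtype(old_api, 1))                # 1
print(call_with_optional_dtype(new_api, 1, dtype="fp16"))  # (1, 'fp16')
```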
@@ -624,29 +649,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                            " those weights or else make sure your checkpoint file is correct."
                        )
 
-                    unexpected_keys = []
-
-                    empty_state_dict = model.state_dict()
-                    for param_name, param in state_dict.items():
-                        accepts_dtype = "dtype" in set(
-                            inspect.signature(set_module_tensor_to_device).parameters.keys()
-                        )
-
-                        if param_name not in empty_state_dict:
-                            unexpected_keys.append(param_name)
-                            continue
-
-                        if empty_state_dict[param_name].shape != param.shape:
-                            raise ValueError(
-                                f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-                            )
-
-                        if accepts_dtype:
-                            set_module_tensor_to_device(
-                                model, param_name, param_device, value=param, dtype=torch_dtype
-                            )
-                        else:
-                            set_module_tensor_to_device(model, param_name, param_device, value=param)
+                    unexpected_keys = load_model_dict_into_meta(
+                        model,
+                        state_dict,
+                        device=param_device,
+                        dtype=torch_dtype,
+                        model_name_or_path=pretrained_model_name_or_path,
+                    )
 
                    if cls._keys_to_ignore_on_load_unexpected is not None:
                        for pat in cls._keys_to_ignore_on_load_unexpected: