
[Lora] Speed up lora loading (#4994)

* speed up lora loading
* Apply suggestions from code review
* up
* up
* Fix more
* Correct more
* Apply suggestions from code review
* up
* Fix more
* Fix more
* up
* up
Commit 37cb819df5 (parent f64d52dbca)
Patrick von Platen, 2023-09-12 17:51:15 +02:00, committed by GitHub
2 changed files with 148 additions and 109 deletions
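The commit threads a `low_cpu_mem_usage` flag through the LoRA loading entry points (see the hunks below). A minimal usage sketch of what this enables — the checkpoint and LoRA repo ids are placeholders, not part of this commit:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# With accelerate installed, low_cpu_mem_usage defaults to True: LoRA layers are
# built on the meta device and filled directly from the checkpoint, skipping the
# random-init-then-overwrite round trip.
pipe.load_lora_weights("some-user/some-lora", low_cpu_mem_usage=True)
```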

src/diffusers/loaders.py

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import os
 import re
 import warnings
@@ -27,6 +26,7 @@ import torch
 from huggingface_hub import hf_hub_download, model_info
 from torch import nn
 
+from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
@@ -46,7 +46,6 @@ if is_transformers_available():
 if is_accelerate_available():
     from accelerate import init_empty_weights
     from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
-    from accelerate.utils import set_module_tensor_to_device
 
 logger = logging.get_logger(__name__)
@@ -137,7 +136,6 @@ class PatchedLoraProjection(nn.Module):
         self.w_down = None
 
     def forward(self, input):
-        # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
         if self.lora_scale is None:
             self.lora_scale = 1.0
         if self.lora_linear_layer is None:
@@ -274,6 +272,11 @@ class UNet2DConditionLoadersMixin:
             use_auth_token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
@@ -300,6 +303,7 @@ class UNet2DConditionLoadersMixin:
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)
@@ -316,6 +320,15 @@ class UNet2DConditionLoadersMixin:
             "framework": "pytorch",
         }
 
+        if low_cpu_mem_usage and not is_accelerate_available():
+            low_cpu_mem_usage = False
+            logger.warning(
+                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                " install accelerate\n```\n."
+            )
+
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             # Let's first try to load .safetensors weights
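The fast path relies on accelerate's `init_empty_weights`, hence the fallback warning above. The context manager makes module construction allocation-free — a small illustration, independent of the diff:

```python
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights():
    layer = nn.Linear(4096, 4096)  # parameters are created on the "meta" device

print(layer.weight.device)  # meta: no CPU RAM allocated, no random init executed
```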
@@ -370,6 +383,10 @@ class UNet2DConditionLoadersMixin:
         # correct keys
         state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
 
+        if network_alphas is not None:
+            network_alphas_keys = list(network_alphas.keys())
+            used_network_alphas_keys = set()
+
         lora_grouped_dict = defaultdict(dict)
         mapped_network_alphas = {}
@@ -381,13 +398,13 @@ class UNet2DConditionLoadersMixin:
                 # Create another `mapped_network_alphas` dictionary so that we can properly map them.
                 if network_alphas is not None:
-                    network_alphas_ = copy.deepcopy(network_alphas)
-                    for k in network_alphas_:
+                    for k in network_alphas_keys:
                         if k.replace(".alpha", "") in key:
-                            mapped_network_alphas.update({attn_processor_key: network_alphas.pop(k)})
+                            mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
+                            used_network_alphas_keys.add(k)
 
         if not is_network_alphas_none:
-            if len(network_alphas) > 0:
+            if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
                 raise ValueError(
                     f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
                 )
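The old code deep-copied `network_alphas` so it could `pop` entries while iterating; the new code iterates a frozen key list, reads entries with `get`, and records consumed keys in a set, so the final emptiness check becomes a set difference. Schematically:

```python
network_alphas = {"down.alpha": 1.0, "mid.alpha": 2.0}  # toy example
network_alphas_keys = list(network_alphas.keys())
used_network_alphas_keys = set()

for k in network_alphas_keys:
    used_network_alphas_keys.add(k)  # marked as each alpha is mapped to a layer

leftover = set(network_alphas_keys) - used_network_alphas_keys
assert not leftover  # the loader raises if any alpha never matched a module
```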
@@ -411,29 +428,38 @@ class UNet2DConditionLoadersMixin:
                     out_features = attn_processor.out_channels
                     kernel_size = attn_processor.kernel_size
-                    lora = LoRAConv2dLayer(
-                        in_features=in_features,
-                        out_features=out_features,
-                        rank=rank,
-                        kernel_size=kernel_size,
-                        stride=attn_processor.stride,
-                        padding=attn_processor.padding,
-                        network_alpha=mapped_network_alphas.get(key),
-                    )
+
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        lora = LoRAConv2dLayer(
+                            in_features=in_features,
+                            out_features=out_features,
+                            rank=rank,
+                            kernel_size=kernel_size,
+                            stride=attn_processor.stride,
+                            padding=attn_processor.padding,
+                            network_alpha=mapped_network_alphas.get(key),
+                        )
                 elif isinstance(attn_processor, LoRACompatibleLinear):
-                    lora = LoRALinearLayer(
-                        attn_processor.in_features,
-                        attn_processor.out_features,
-                        rank,
-                        mapped_network_alphas.get(key),
-                    )
+                    ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                    with ctx():
+                        lora = LoRALinearLayer(
+                            attn_processor.in_features,
+                            attn_processor.out_features,
+                            rank,
+                            mapped_network_alphas.get(key),
+                        )
                 else:
                     raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
 
                 value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
-                lora.load_state_dict(value_dict)
                 lora_layers_list.append((attn_processor, lora))
 
+                if low_cpu_mem_usage:
+                    device = next(iter(value_dict.values())).device
+                    dtype = next(iter(value_dict.values())).dtype
+                    load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
+                else:
+                    lora.load_state_dict(value_dict)
+
         elif is_custom_diffusion:
             attn_processors = {}
             custom_diffusion_grouped_dict = defaultdict(dict)
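Both branches above follow the same pattern: build the LoRA layer under `init_empty_weights` (with `nullcontext` as the no-op stand-in), then materialize its tensors straight from the checkpoint values. A condensed sketch of that pattern, assuming a recent accelerate whose `set_module_tensor_to_device` accepts `dtype`:

```python
from contextlib import nullcontext

import torch
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device

low_cpu_mem_usage = True
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext

with ctx():
    lora = torch.nn.Linear(320, 4, bias=False)  # stand-in for a LoRA layer

value_dict = {"weight": torch.randn(4, 320, dtype=torch.float16)}
if low_cpu_mem_usage:
    # replace each meta tensor directly with the checkpoint value
    for name, value in value_dict.items():
        set_module_tensor_to_device(lora, name, value.device, value=value, dtype=value.dtype)
else:
    lora.load_state_dict(value_dict)
```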
@@ -470,13 +496,12 @@ class UNet2DConditionLoadersMixin:
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
             )
 
-        # set correct dtype & device
-        lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
-
         # set lora layers
         for target_module, lora_layer in lora_layers_list:
             target_module.set_lora_layer(lora_layer)
 
+        self.to(dtype=self.dtype, device=self.device)
+
     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
@@ -999,13 +1024,18 @@ class LoraLoaderMixin:
             recurive = is_sequential_cpu_offload
             remove_hook_from_module(component, recurse=recurive)
 
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+        self.load_lora_into_unet(
+            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+        )
         self.load_lora_into_text_encoder(
             state_dict,
             network_alphas=network_alphas,
             text_encoder=self.text_encoder,
             lora_scale=self.lora_scale,
+            low_cpu_mem_usage=low_cpu_mem_usage,
         )
 
         # Offload back.
# Offload back.
@@ -1065,6 +1095,11 @@ class LoraLoaderMixin:
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
             mirror (`str`, *optional*):
                 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -1305,7 +1340,7 @@ class LoraLoaderMixin:
         return new_state_dict
 
     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1318,7 +1353,13 @@ class LoraLoaderMixin:
                 See `LoRALinearLayer` for more details.
             unet (`UNet2DConditionModel`):
                 The UNet model to load the LoRA layers into.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
         """
+        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
@@ -1343,11 +1384,12 @@ class LoraLoaderMixin:
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
             warnings.warn(warn_message)
 
         # load loras into unet
-        unet.load_attn_procs(state_dict, network_alphas=network_alphas)
+        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
 
     @classmethod
-    def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0):
+    def load_lora_into_text_encoder(
+        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`.
@@ -1364,7 +1406,13 @@ class LoraLoaderMixin:
             lora_scale (`float`):
                 How much to scale the output of the lora linear layer before it is added with the output of the regular
                 lora layer.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
         """
+        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
@@ -1447,6 +1495,7 @@ class LoraLoaderMixin:
                     network_alphas,
                     rank=rank,
                     patch_mlp=patch_mlp,
+                    low_cpu_mem_usage=low_cpu_mem_usage,
                 )
 
             # set correct dtype & device
@@ -1454,12 +1503,23 @@ class LoraLoaderMixin:
                 k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
                 for k, v in text_encoder_lora_state_dict.items()
             }
-            load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
-            if len(load_state_dict_results.unexpected_keys) != 0:
+
+            if low_cpu_mem_usage:
+                device = next(iter(text_encoder_lora_state_dict.values())).device
+                dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
+                unexpected_keys = load_model_dict_into_meta(
+                    text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
+                )
+            else:
+                load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
+                unexpected_keys = load_state_dict_results.unexpected_keys
+
+            if len(unexpected_keys) != 0:
                 raise ValueError(
                     f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                 )
 
+            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
+
     @property
     def lora_scale(self) -> float:
         # property function that returns the lora scale which can be set at run time by the pipeline.
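Both branches funnel into one `unexpected_keys` check. In the non-meta branch, PyTorch's `load_state_dict(strict=False)` returns a named tuple whose `unexpected_keys` field supplies that list — for example:

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)
result = model.load_state_dict(
    {"weight": torch.eye(2), "bias": torch.zeros(2), "extra": torch.zeros(1)},
    strict=False,
)
print(result.unexpected_keys)  # ['extra']
```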
@@ -1492,11 +1552,21 @@ class LoraLoaderMixin:
         rank: Union[Dict[str, int], int] = 4,
         dtype=None,
         patch_mlp=False,
+        low_cpu_mem_usage=False,
     ):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
         """
+
+        def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
+            linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
+            ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+            with ctx():
+                model = PatchedLoraProjection(linear_layer, lora_scale, network_alpha, rank, dtype=dtype)
+
+            lora_parameters.extend(model.lora_linear_layer.parameters())
+
+            return model
+
         # First, remove any monkey-patch that might have been applied before
         cls._remove_text_encoder_monkey_patch_classmethod(text_encoder)
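`create_patched_linear_lora` factors out a wrap-and-replace step that previously appeared verbatim for the q/k/v/out projections and both MLP layers. The underlying idea — swap an `nn.Linear` for a wrapper that adds a scaled low-rank path — looks roughly like this generic sketch (not the actual `PatchedLoraProjection` implementation):

```python
import torch.nn as nn

class LinearWithLoRA(nn.Module):
    """Hypothetical wrapper illustrating the monkey-patch idea."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # the LoRA path starts as an exact no-op
        self.scale = scale

    def forward(self, x):
        return self.base(x) + self.scale * self.up(self.down(x))

# patching in place, unwrapping first if the module was already patched:
# attn_module.q_proj = LinearWithLoRA(getattr(attn_module.q_proj, "base", attn_module.q_proj))
```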
@@ -1515,45 +1585,18 @@ class LoraLoaderMixin:
                 else:
                     current_rank = rank
 
-                q_linear_layer = (
-                    attn_module.q_proj.regular_linear_layer
-                    if isinstance(attn_module.q_proj, PatchedLoraProjection)
-                    else attn_module.q_proj
-                )
-                attn_module.q_proj = PatchedLoraProjection(
-                    q_linear_layer, lora_scale, network_alpha=query_alpha, rank=current_rank, dtype=dtype
-                )
-                lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters())
-
-                k_linear_layer = (
-                    attn_module.k_proj.regular_linear_layer
-                    if isinstance(attn_module.k_proj, PatchedLoraProjection)
-                    else attn_module.k_proj
-                )
-                attn_module.k_proj = PatchedLoraProjection(
-                    k_linear_layer, lora_scale, network_alpha=key_alpha, rank=current_rank, dtype=dtype
-                )
-                lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters())
-
-                v_linear_layer = (
-                    attn_module.v_proj.regular_linear_layer
-                    if isinstance(attn_module.v_proj, PatchedLoraProjection)
-                    else attn_module.v_proj
-                )
-                attn_module.v_proj = PatchedLoraProjection(
-                    v_linear_layer, lora_scale, network_alpha=value_alpha, rank=current_rank, dtype=dtype
-                )
-                lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters())
-
-                out_linear_layer = (
-                    attn_module.out_proj.regular_linear_layer
-                    if isinstance(attn_module.out_proj, PatchedLoraProjection)
-                    else attn_module.out_proj
-                )
-                attn_module.out_proj = PatchedLoraProjection(
-                    out_linear_layer, lora_scale, network_alpha=out_alpha, rank=current_rank, dtype=dtype
-                )
-                lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters())
+                attn_module.q_proj = create_patched_linear_lora(
+                    attn_module.q_proj, query_alpha, current_rank, dtype, lora_parameters
+                )
+                attn_module.k_proj = create_patched_linear_lora(
+                    attn_module.k_proj, key_alpha, current_rank, dtype, lora_parameters
+                )
+                attn_module.v_proj = create_patched_linear_lora(
+                    attn_module.v_proj, value_alpha, current_rank, dtype, lora_parameters
+                )
+                attn_module.out_proj = create_patched_linear_lora(
+                    attn_module.out_proj, out_alpha, current_rank, dtype, lora_parameters
+                )
 
             if patch_mlp:
                 for name, mlp_module in text_encoder_mlp_modules(text_encoder):
@@ -1563,25 +1606,12 @@ class LoraLoaderMixin:
                     current_rank_fc1 = rank.pop(f"{name}.fc1.lora_linear_layer.up.weight")
                     current_rank_fc2 = rank.pop(f"{name}.fc2.lora_linear_layer.up.weight")
 
-                    fc1_linear_layer = (
-                        mlp_module.fc1.regular_linear_layer
-                        if isinstance(mlp_module.fc1, PatchedLoraProjection)
-                        else mlp_module.fc1
-                    )
-                    mlp_module.fc1 = PatchedLoraProjection(
-                        fc1_linear_layer, lora_scale, network_alpha=fc1_alpha, rank=current_rank_fc1, dtype=dtype
-                    )
-                    lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters())
-
-                    fc2_linear_layer = (
-                        mlp_module.fc2.regular_linear_layer
-                        if isinstance(mlp_module.fc2, PatchedLoraProjection)
-                        else mlp_module.fc2
-                    )
-                    mlp_module.fc2 = PatchedLoraProjection(
-                        fc2_linear_layer, lora_scale, network_alpha=fc2_alpha, rank=current_rank_fc2, dtype=dtype
-                    )
-                    lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters())
+                    mlp_module.fc1 = create_patched_linear_lora(
+                        mlp_module.fc1, fc1_alpha, current_rank_fc1, dtype, lora_parameters
+                    )
+                    mlp_module.fc2 = create_patched_linear_lora(
+                        mlp_module.fc2, fc2_alpha, current_rank_fc2, dtype, lora_parameters
+                    )
 
         if is_network_alphas_populated and len(network_alphas) > 0:
             raise ValueError(
@@ -2375,8 +2405,7 @@ class FromOriginalVAEMixin:
         vae = AutoencoderKL(**vae_config)
 
         if is_accelerate_available():
-            for param_name, param in converted_vae_checkpoint.items():
-                set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+            load_model_dict_into_meta(vae, converted_vae_checkpoint, device="cpu")
         else:
             vae.load_state_dict(converted_vae_checkpoint)

src/diffusers/models/modeling_utils.py

@@ -128,6 +128,31 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
         )
 
 
+def load_model_dict_into_meta(model, state_dict, device=None, dtype=None, model_name_or_path=None):
+    device = device or torch.device("cpu")
+    dtype = dtype or torch.float32
+
+    unexpected_keys = []
+    empty_state_dict = model.state_dict()
+    for param_name, param in state_dict.items():
+        if param_name not in empty_state_dict:
+            unexpected_keys.append(param_name)
+            continue
+
+        if empty_state_dict[param_name].shape != param.shape:
+            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+            raise ValueError(
+                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+            )
+
+        accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+        if accepts_dtype:
+            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, param_name, device, value=param)
+
+    return unexpected_keys
+
+
 def _load_state_dict_into_model(model_to_load, state_dict):
     # Convert old format to new format if needed from a PyTorch state_dict
     # copy state_dict so _load_from_state_dict can modify it
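The `accepts_dtype` probe exists because older accelerate releases ship a `set_module_tensor_to_device` without a `dtype` parameter; `inspect.signature` turns that into an explicit capability check. In isolation:

```python
import inspect
from accelerate.utils import set_module_tensor_to_device

accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
# True on recent accelerate; when False, the caller must cast the value beforehand
```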
@@ -624,29 +649,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                         " those weights or else make sure your checkpoint file is correct."
                     )
-                    unexpected_keys = []
-
-                    empty_state_dict = model.state_dict()
-                    for param_name, param in state_dict.items():
-                        accepts_dtype = "dtype" in set(
-                            inspect.signature(set_module_tensor_to_device).parameters.keys()
-                        )
-
-                        if param_name not in empty_state_dict:
-                            unexpected_keys.append(param_name)
-                            continue
-
-                        if empty_state_dict[param_name].shape != param.shape:
-                            raise ValueError(
-                                f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-                            )
-
-                        if accepts_dtype:
-                            set_module_tensor_to_device(
-                                model, param_name, param_device, value=param, dtype=torch_dtype
-                            )
-                        else:
-                            set_module_tensor_to_device(model, param_name, param_device, value=param)
+
+                    unexpected_keys = load_model_dict_into_meta(
+                        model,
+                        state_dict,
+                        device=param_device,
+                        dtype=torch_dtype,
+                        model_name_or_path=pretrained_model_name_or_path,
+                    )
 
                 if cls._keys_to_ignore_on_load_unexpected is not None:
                     for pat in cls._keys_to_ignore_on_load_unexpected:
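To check the speed-up locally, a simple timing harness suffices (repo ids are placeholders, and timings depend on hardware and checkpoint size):

```python
import time

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

for flag in (False, True):
    start = time.perf_counter()
    pipe.load_lora_weights("some-user/some-lora", low_cpu_mem_usage=flag)  # placeholder repo
    print(f"low_cpu_mem_usage={flag}: {time.perf_counter() - start:.2f}s")
    pipe.unload_lora_weights()
```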