
Speed up the peft lora unload (#5741)

* Update peft_utils.py

* fix bug

* make the util backwards compatible.

Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fix import issue

* refactor the backward compatibility condition

* rename the conditional variable

* address comments

Co-Authored-By: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* address comment

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Authored by Sourab Mangrulkar on 2023-11-17 23:45:44 +05:30, committed by GitHub
parent c6f90daea6
commit 6f1435332b


@@ -23,55 +23,77 @@ from packaging import version
 from .import_utils import is_peft_available, is_torch_available
 
 
-def recurse_remove_peft_layers(model):
-    if is_torch_available():
-        import torch
+if is_torch_available():
+    import torch
 
+
+def recurse_remove_peft_layers(model):
     r"""
     Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
     """
-    from peft.tuners.lora import LoraLayer
+    from peft.tuners.tuners_utils import BaseTunerLayer
 
-    for name, module in model.named_children():
-        if len(list(module.children())) > 0:
-            ## compound module, go inside it
-            recurse_remove_peft_layers(module)
+    has_base_layer_pattern = False
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            has_base_layer_pattern = hasattr(module, "base_layer")
+            break
 
-        module_replaced = False
+    if has_base_layer_pattern:
+        from peft.utils import _get_submodules
 
-        if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
-            new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
-                module.weight.device
-            )
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
+        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        for key in key_list:
+            try:
+                parent, target, target_name = _get_submodules(model, key)
+            except AttributeError:
+                continue
+            if hasattr(target, "base_layer"):
+                setattr(parent, target_name, target.get_base_layer())
+    else:
+        # This is for backwards compatibility with PEFT <= 0.6.2.
+        # TODO can be removed once that PEFT version is no longer supported.
+        from peft.tuners.lora import LoraLayer
 
-            module_replaced = True
-        elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
-            new_module = torch.nn.Conv2d(
-                module.in_channels,
-                module.out_channels,
-                module.kernel_size,
-                module.stride,
-                module.padding,
-                module.dilation,
-                module.groups,
-            ).to(module.weight.device)
+        for name, module in model.named_children():
+            if len(list(module.children())) > 0:
+                ## compound module, go inside it
+                recurse_remove_peft_layers(module)
 
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
+            module_replaced = False
 
-            module_replaced = True
+            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
+                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
+                    module.weight.device
+                )
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
 
-        if module_replaced:
-            setattr(model, name, new_module)
-            del module
+                module_replaced = True
+            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
+                new_module = torch.nn.Conv2d(
+                    module.in_channels,
+                    module.out_channels,
+                    module.kernel_size,
+                    module.stride,
+                    module.padding,
+                    module.dilation,
+                    module.groups,
+                ).to(module.weight.device)
 
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
+
+                module_replaced = True
+            if module_replaced:
+                setattr(model, name, new_module)
+                del module
+
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
     return model
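
For orientation, a minimal usage sketch follows; it is not part of the commit, and it assumes a PEFT release newer than 0.6.2 (so LoRA layers expose `base_layer`) plus a made-up `TinyModel` and LoRA settings chosen only for illustration:

import torch
from peft import LoraConfig, inject_adapter_in_model

from diffusers.utils.peft_utils import recurse_remove_peft_layers


class TinyModel(torch.nn.Module):
    # toy module with a single linear layer for LoRA to target (illustrative only)
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


model = TinyModel()
# wrap `linear` in a PEFT LoRA layer
model = inject_adapter_in_model(LoraConfig(r=4, lora_alpha=4, target_modules=["linear"]), model)

# fast path added by this commit: each LoRA layer carries a `base_layer`, so unloading
# re-attaches the original torch.nn.Linear instead of rebuilding it and copying weights
model = recurse_remove_peft_layers(model)
assert isinstance(model.linear, torch.nn.Linear)

Older PEFT versions without `base_layer` fall back to the previous per-module rebuild path, so the call behaves the same either way.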