Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
Speed up the peft lora unload (#5741)
* Update peft_utils.py

* fix bug

* make the util backwards compatible.

Co-Authored-By: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* fix import issue

* refactor the backward compatibility condition

* rename the conditional variable

* address comments

Co-Authored-By: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* address comment

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Committed by GitHub
parent c6f90daea6
commit 6f1435332b
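For orientation before the diff (this note and the snippet are not part of the commit): in diffusers, recurse_remove_peft_layers runs when LoRA weights are unloaded from a pipeline using the PEFT backend, so a typical way to reach the code path this commit speeds up looks like the sketch below; the checkpoint and LoRA repo ids are placeholders.

    # Hedged usage sketch; the model and LoRA repo ids are placeholders.
    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe.load_lora_weights("some-user/some-lora")  # placeholder repo id

    # Unloading strips the injected PEFT LoRA layers from the pipeline's
    # models; this is the operation the commit makes faster.
    pipe.unload_lora_weights()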
peft_utils.py
@@ -23,55 +23,77 @@ from packaging import version
 from .import_utils import is_peft_available, is_torch_available


-def recurse_remove_peft_layers(model):
-    if is_torch_available():
-        import torch
-
+if is_torch_available():
+    import torch
+
+
+def recurse_remove_peft_layers(model):
     r"""
     Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
     """
-    from peft.tuners.lora import LoraLayer
+    from peft.tuners.tuners_utils import BaseTunerLayer

-    for name, module in model.named_children():
-        if len(list(module.children())) > 0:
-            ## compound module, go inside it
-            recurse_remove_peft_layers(module)
+    has_base_layer_pattern = False
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            has_base_layer_pattern = hasattr(module, "base_layer")
+            break

-        module_replaced = False
+    if has_base_layer_pattern:
+        from peft.utils import _get_submodules

-        if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
-            new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
-                module.weight.device
-            )
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
+        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        for key in key_list:
+            try:
+                parent, target, target_name = _get_submodules(model, key)
+            except AttributeError:
+                continue
+            if hasattr(target, "base_layer"):
+                setattr(parent, target_name, target.get_base_layer())
+    else:
+        # This is for backwards compatibility with PEFT <= 0.6.2.
+        # TODO can be removed once that PEFT version is no longer supported.
+        from peft.tuners.lora import LoraLayer

-            module_replaced = True
-        elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
-            new_module = torch.nn.Conv2d(
-                module.in_channels,
-                module.out_channels,
-                module.kernel_size,
-                module.stride,
-                module.padding,
-                module.dilation,
-                module.groups,
-            ).to(module.weight.device)
+        for name, module in model.named_children():
+            if len(list(module.children())) > 0:
+                ## compound module, go inside it
+                recurse_remove_peft_layers(module)

-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
+            module_replaced = False

-            module_replaced = True
+            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
+                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
+                    module.weight.device
+                )
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias

-        if module_replaced:
-            setattr(model, name, new_module)
-            del module
+                module_replaced = True
+            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
+                new_module = torch.nn.Conv2d(
+                    module.in_channels,
+                    module.out_channels,
+                    module.kernel_size,
+                    module.stride,
+                    module.padding,
+                    module.dilation,
+                    module.groups,
+                ).to(module.weight.device)

-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
+
+                module_replaced = True
+
+            if module_replaced:
+                setattr(model, name, new_module)
+                del module
+
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
     return model
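As a quick illustration of what the new fast path buys: with a PEFT version that uses the base_layer pattern (newer than 0.6.2), unloading becomes a setattr of the preserved base module per LoRA layer, instead of constructing fresh torch.nn modules and copying weights into them. A minimal sketch, assuming peft's inject_adapter_in_model API and an invented toy module name ("to_q"):

    # Toy demonstration of the base_layer fast path; the module name and
    # LoraConfig values are illustrative assumptions, not from the commit.
    import torch
    from peft import LoraConfig, inject_adapter_in_model
    from diffusers.utils.peft_utils import recurse_remove_peft_layers

    model = torch.nn.ModuleDict({"to_q": torch.nn.Linear(16, 16)})
    model = inject_adapter_in_model(LoraConfig(r=4, target_modules=["to_q"]), model)

    # The LoRA wrapper keeps the original nn.Linear as `base_layer`, so
    # unloading can simply re-attach it rather than rebuild the layer.
    model = recurse_remove_peft_layers(model)
    assert isinstance(model["to_q"], torch.nn.Linear)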