Mirror of https://github.com/huggingface/diffusers.git

PEFT integration features for the text encoder

1. Support multiple rank/alpha values
2. Support multiple active adapters
3. Support disabling and enabling adapters
Author: Sourab Mangrulkar
Date: 2023-09-22 12:35:01 +05:30
parent 920333ffaa
commit 0985d17ea9
4 changed files with 170 additions and 17 deletions
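
The diff below adds an `adapter_name` argument to `load_lora_weights` / `load_lora_into_text_encoder` and introduces `set_adapter`, `disable_lora`, and `enable_lora` on `LoraLoaderMixin`, backed by new helpers in `peft_utils`. A minimal, hedged usage sketch of the resulting pipeline API (checkpoint paths and adapter names are hypothetical, the PEFT backend must be enabled, and per the TODOs in this commit `set_adapter`, `disable_lora`, and `enable_lora` currently only touch the text encoders, not the UNet):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# load two LoRA checkpoints under explicit adapter names (paths are hypothetical)
pipe.load_lora_weights("lora-checkpoints/style_a", adapter_name="style_a")
pipe.load_lora_weights("lora-checkpoints/style_b", adapter_name="style_b")

# activate both adapters on the text encoder(s), each with its own weight
pipe.set_adapter(["style_a", "style_b"], te_weights=[0.8, 0.5])

# temporarily switch every LoRA layer off, then back on, without unloading
pipe.disable_lora()
pipe.enable_lora()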


@@ -35,12 +35,17 @@ from .utils import (
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
get_adapter_name,
get_rank_and_alpha_pattern,
is_accelerate_available,
is_omegaconf_available,
is_peft_available,
is_transformers_available,
logging,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .utils.import_utils import BACKENDS_MAPPING
@@ -1100,7 +1105,9 @@ class LoraLoaderMixin:
num_fused_loras = 0
use_peft_backend = USE_PEFT_BACKEND
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
@@ -1144,6 +1151,7 @@ class LoraLoaderMixin:
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
_pipeline=self,
adapter_name=adapter_name,
)
@classmethod
@@ -1500,6 +1508,7 @@ class LoraLoaderMixin:
lora_scale=1.0,
low_cpu_mem_usage=None,
_pipeline=None,
adapter_name=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1522,6 +1531,9 @@ class LoraLoaderMixin:
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
@@ -1583,18 +1595,30 @@ class LoraLoaderMixin:
if cls.use_peft_backend:
from peft import LoraConfig
lora_rank = list(rank.values())[0]
# By definition, the scale should be alpha divided by rank.
# https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
alpha = lora_scale * lora_rank
r, lora_alpha, rank_pattern, alpha_pattern, target_modules = get_rank_and_alpha_pattern(
rank, network_alphas, text_encoder_lora_state_dict
)
target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
if patch_mlp:
target_modules += ["fc1", "fc2"]
lora_config = LoraConfig(
r=r,
target_modules=target_modules,
lora_alpha=lora_alpha,
rank_pattern=rank_pattern,
alpha_pattern=alpha_pattern,
)
lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
# inject LoRA layers and load the state dict
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, lora_weightage=lora_scale)
is_model_cpu_offload = False
is_sequential_cpu_offload = False
@@ -2169,6 +2193,65 @@ class LoraLoaderMixin:
self.num_fused_loras -= 1
def set_adapter(
self,
adapter_names: Union[List[str], str],
unet_weights: List[float] = None,
te_weights: List[float] = None,
te2_weights: List[float] = None,
):
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")
def process_weights(adapter_names, weights):
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
weights = [weights]
if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)
return weights
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
# To Do
# Handle the UNET
# Handle the Text Encoder
te_weights = process_weights(adapter_names, te_weights)
if hasattr(self, "text_encoder"):
set_weights_and_activate_adapters(self.text_encoder, adapter_names, te_weights)
te2_weights = process_weights(adapter_names, te2_weights)
if hasattr(self, "text_encoder_2"):
set_weights_and_activate_adapters(self.text_encoder_2, adapter_names, te2_weights)
def disable_lora(self):
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")
# To Do
# Disable unet adapters
# Disable text encoder adapters
if hasattr(self, "text_encoder"):
set_adapter_layers(self.text_encoder, enabled=False)
if hasattr(self, "text_encoder_2"):
set_adapter_layers(self.text_encoder_2, enabled=False)
def enable_lora(self):
if not self.use_peft_backend:
raise ValueError("PEFT backend is required for this method.")
# To Do
# Enable unet adapters
# Enable text encoder adapters
if hasattr(self, "text_encoder"):
set_adapter_layers(self.text_encoder, enabled=True)
if hasattr(self, "text_encoder_2"):
set_adapter_layers(self.text_encoder_2, enabled=True)
class FromSingleFileMixin:
"""


@@ -19,7 +19,7 @@ import torch.nn.functional as F
from torch import nn
from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
from ..utils import logging
from ..utils import logging, scale_lora_layers
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -27,11 +27,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
if use_peft_backend:
from peft.tuners.lora import LoraLayer
for module in text_encoder.modules():
if isinstance(module, LoraLayer):
module.scaling[module.active_adapter] = lora_scale
scale_lora_layers(text_encoder, lora_weightage=lora_scale)
else:
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):


@@ -83,7 +83,14 @@ from .import_utils import (
from .loading_utils import load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import recurse_remove_peft_layers
from .peft_utils import (
get_adapter_name,
get_rank_and_alpha_pattern,
recurse_remove_peft_layers,
scale_lora_layers,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft


@@ -69,3 +69,70 @@ def recurse_remove_peft_layers(model):
torch.cuda.empty_cache()
return model
def scale_lora_layers(model, lora_weightage):
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.scale_layer(lora_weightage)
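# Hedged usage sketch: `load_lora_into_text_encoder` above calls this helper right after
# injecting the adapter in order to apply the user-supplied `lora_scale`, e.g.
#
#     scale_lora_layers(pipe.text_encoder, lora_weightage=0.7)  # 0.7 is an arbitrary example value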
def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict):
rank_pattern = None
alpha_pattern = None
r = lora_alpha = list(rank_dict.values())[0]
if len(set(rank_dict.values())) > 1:
# get the rank that occurs most often
r = max(set(rank_dict.values()), key=list(rank_dict.values()).count)
# for modules whose rank differs from the most common rank, add them to `rank_pattern`
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
# get the alpha that occurs most often
lora_alpha = max(set(network_alpha_dict.values()), key=list(network_alpha_dict.values()).count)
# for modules whose alpha differs from the most common alpha, add them to `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
# layer names without the Diffusers-specific LoRA suffix
target_modules = {name.split(".lora")[0] for name in peft_state_dict.keys()}
return r, lora_alpha, rank_pattern, alpha_pattern, target_modules
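# For illustration only (module names are hypothetical): given a mixed-rank state dict such as
#
#     rank_dict = {
#         "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": 4,
#         "text_model.encoder.layers.0.self_attn.k_proj.lora_B.weight": 4,
#         "text_model.encoder.layers.0.self_attn.v_proj.lora_B.weight": 8,
#     }
#
# the helper picks the most frequent rank as `r` (4 here) and records the outlier in
# `rank_pattern` as {"text_model.encoder.layers.0.self_attn.v_proj": 8}. `alpha_pattern`
# is derived the same way from `network_alpha_dict`, and `target_modules` is the set of
# layer names with their `.lora*` suffix stripped, so the result can be passed straight to `LoraConfig`.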
def get_adapter_name(model):
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return f"default_{len(module.r)}"
return "default_0"
def set_adapter_layers(model, enabled=True):
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.disable_adapters = False if enabled else True
def set_weights_and_activate_adapters(model, adapter_names, weights):
from peft.tuners.tuners_utils import BaseTunerLayer
# iterate over each adapter, make it active and set the corresponding scaling weight
for adapter_name, weight in zip(adapter_names, weights):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_name
module.scale_layer(weight)
# set multiple active adapters
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.active_adapter = adapter_names
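# Hedged end-to-end sketch of the helpers above on a text encoder that already has the
# (hypothetical) adapters "style_a" and "style_b" injected:
#
#     set_weights_and_activate_adapters(pipe.text_encoder, ["style_a", "style_b"], [0.8, 0.5])
#     set_adapter_layers(pipe.text_encoder, enabled=False)  # mute every adapter
#     set_adapter_layers(pipe.text_encoder, enabled=True)   # re-enable them
#
# `set_weights_and_activate_adapters` scales each adapter individually and then assigns the
# full list to `active_adapter`, which is how `LoraLoaderMixin.set_adapter` activates several
# adapters at once.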