mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
peft integration features for text encoder
1. support multiple rank/alpha values 2. support multiple active adapters 3. support disabling and enabling adapters
This commit is contained in:
@@ -35,12 +35,17 @@ from .utils import (
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_rank_and_alpha_pattern,
|
||||
is_accelerate_available,
|
||||
is_omegaconf_available,
|
||||
is_peft_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
@@ -1100,7 +1105,9 @@ class LoraLoaderMixin:
|
||||
num_fused_loras = 0
|
||||
use_peft_backend = USE_PEFT_BACKEND
|
||||
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
||||
`self.text_encoder`.
|
||||
@@ -1144,6 +1151,7 @@ class LoraLoaderMixin:
|
||||
lora_scale=self.lora_scale,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
_pipeline=self,
|
||||
adapter_name=adapter_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1500,6 +1508,7 @@ class LoraLoaderMixin:
|
||||
lora_scale=1.0,
|
||||
low_cpu_mem_usage=None,
|
||||
_pipeline=None,
|
||||
adapter_name=None,
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `text_encoder`
|
||||
@@ -1522,6 +1531,9 @@ class LoraLoaderMixin:
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
@@ -1583,18 +1595,30 @@ class LoraLoaderMixin:
|
||||
if cls.use_peft_backend:
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_rank = list(rank.values())[0]
|
||||
# By definition, the scale should be alpha divided by rank.
|
||||
# https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
|
||||
alpha = lora_scale * lora_rank
|
||||
r, lora_alpha, rank_pattern, alpha_pattern, target_modules = get_rank_and_alpha_pattern(
|
||||
rank, network_alphas, text_encoder_lora_state_dict
|
||||
)
|
||||
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
if patch_mlp:
|
||||
target_modules += ["fc1", "fc2"]
|
||||
lora_config = LoraConfig(
|
||||
r=r,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=lora_alpha,
|
||||
rank_pattern=rank_pattern,
|
||||
alpha_pattern=alpha_pattern,
|
||||
)
|
||||
|
||||
lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
|
||||
# inject LoRA layers and load the state dict
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, lora_weightage=lora_scale)
|
||||
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
@@ -2169,6 +2193,65 @@ class LoraLoaderMixin:
|
||||
|
||||
self.num_fused_loras -= 1
|
||||
|
||||
def set_adapter(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
unet_weights: List[float] = None,
|
||||
te_weights: List[float] = None,
|
||||
te2_weights: List[float] = None,
|
||||
):
|
||||
if not self.use_peft_backend:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
def process_weights(adapter_names, weights):
|
||||
if weights is None:
|
||||
weights = [1.0] * len(adapter_names)
|
||||
elif isinstance(weights, float):
|
||||
weights = [weights]
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||
)
|
||||
return weights
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
|
||||
# To Do
|
||||
# Handle the UNET
|
||||
|
||||
# Handle the Text Encoder
|
||||
te_weights = process_weights(adapter_names, te_weights)
|
||||
if hasattr(self, "text_encoder"):
|
||||
set_weights_and_activate_adapters(self.text_encoder, adapter_names, te_weights)
|
||||
te2_weights = process_weights(adapter_names, te2_weights)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
set_weights_and_activate_adapters(self.text_encoder_2, adapter_names, te2_weights)
|
||||
|
||||
def disable_lora(self):
|
||||
if not self.use_peft_backend:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
# To Do
|
||||
# Disbale unet adapters
|
||||
|
||||
# Disbale text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
set_adapter_layers(self.text_encoder, enabled=False)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
set_adapter_layers(self.text_encoder_2, enabled=False)
|
||||
|
||||
def enable_lora(self):
|
||||
if not self.use_peft_backend:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
# To Do
|
||||
# Enable unet adapters
|
||||
|
||||
# Enable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
set_adapter_layers(self.text_encoder, enabled=True)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
set_adapter_layers(self.text_encoder_2, enabled=True)
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
|
||||
@@ -19,7 +19,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
from ..utils import logging
|
||||
from ..utils import logging, scale_lora_layers
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -27,11 +27,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
|
||||
if use_peft_backend:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
module.scaling[module.active_adapter] = lora_scale
|
||||
scale_lora_layers(text_encoder, lora_weightage=lora_scale)
|
||||
else:
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
|
||||
@@ -83,7 +83,14 @@ from .import_utils import (
|
||||
from .loading_utils import load_image
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .peft_utils import recurse_remove_peft_layers
|
||||
from .peft_utils import (
|
||||
get_adapter_name,
|
||||
get_rank_and_alpha_pattern,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
|
||||
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
|
||||
|
||||
|
||||
@@ -69,3 +69,70 @@ def recurse_remove_peft_layers(model):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def scale_lora_layers(model, lora_weightage):
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.scale_layer(lora_weightage)
|
||||
|
||||
|
||||
def get_rank_and_alpha_pattern(rank_dict, network_alpha_dict, peft_state_dict):
|
||||
rank_pattern = None
|
||||
alpha_pattern = None
|
||||
r = lora_alpha = list(rank_dict.values())[0]
|
||||
if len(set(rank_dict.values())) > 1:
|
||||
# get the rank occuring the most number of times
|
||||
r = max(set(rank_dict.values()), key=list(rank_dict.values()).count)
|
||||
|
||||
# for modules with rank different from the most occuring rank, add it to the `rank_pattern`
|
||||
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
|
||||
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
|
||||
|
||||
if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
|
||||
# get the alpha occuring the most number of times
|
||||
lora_alpha = max(set(network_alpha_dict.values()), key=list(network_alpha_dict.values()).count)
|
||||
|
||||
# for modules with alpha different from the most occuring alpha, add it to the `alpha_pattern`
|
||||
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
|
||||
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
|
||||
|
||||
# layer names without the Diffusers specific
|
||||
target_modules = {name.split(".lora")[0] for name in peft_state_dict.keys()}
|
||||
|
||||
return r, lora_alpha, rank_pattern, alpha_pattern, target_modules
|
||||
|
||||
|
||||
def get_adapter_name(model):
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return f"default_{len(module.r)}"
|
||||
return "default_0"
|
||||
|
||||
|
||||
def set_adapter_layers(model, enabled=True):
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.disable_adapters = False if enabled else True
|
||||
|
||||
|
||||
def set_weights_and_activate_adapters(model, adapter_names, weights):
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
|
||||
# iterate over each adapter, make it active and set the corresponding scaling weight
|
||||
for adapter_name, weight in zip(adapter_names, weights):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.active_adapter = adapter_name
|
||||
module.scale_layer(weight)
|
||||
|
||||
# set multiple active adapters
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.active_adapter = adapter_names
|
||||
|
||||
Reference in New Issue
Block a user