PEFT Integration for Text Encoder to handle multiple alphas/ranks, disable/enable adapters and support for multiple adapters (#5147)
* more fixes
* up
* up
* style
* add in setup
* oops
* more changes
* v1 rzfactor CI
* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* few todos
* protect torch import
* style
* fix fuse text encoder
* Update src/diffusers/loaders.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* replace with `recurse_replace_peft_layers`
* keep old modules for BC
* adjustments on `adjust_lora_scale_text_encoder`
* nit
* move tests
* add conversion utils
* remove unneeded methods
* use class method instead
* oops
* use `base_version`
* fix examples
* fix CI
* fix weird error with python 3.8
* fix
* better fix
* style
* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* add comment
* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* conv2d support for recurse remove
* added docstrings
* more docstring
* add deprecate
* revert
* try to fix merge conflicts
* peft integration features for text encoder:
  1. support multiple rank/alpha values
  2. support multiple active adapters
  3. support disabling and enabling adapters
* fix bug
* fix code quality
* Apply suggestions from code review
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* fix bugs
* Apply suggestions from code review
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
* address comments
Co-Authored-By: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-Authored-By: Patrick von Platen <patrick.v.platen@gmail.com>
* fix code quality
* address comments
* address comments
* Apply suggestions from code review
* find and replace

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
committed by GitHub
parent 940f9410cb
commit 02247d9ce1
@@ -35,18 +35,23 @@ from .utils import (
    convert_state_dict_to_diffusers,
    convert_state_dict_to_peft,
    deprecate,
    get_adapter_name,
    get_peft_kwargs,
    is_accelerate_available,
    is_omegaconf_available,
    is_peft_available,
    is_transformers_available,
    logging,
    recurse_remove_peft_layers,
    scale_lora_layers,
    set_adapter_layers,
    set_weights_and_activate_adapters,
)
from .utils.import_utils import BACKENDS_MAPPING


if is_transformers_available():
    from transformers import CLIPTextModel, CLIPTextModelWithProjection
    from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel

if is_accelerate_available():
    from accelerate import init_empty_weights
@@ -1100,7 +1105,9 @@ class LoraLoaderMixin:
    num_fused_loras = 0
    use_peft_backend = USE_PEFT_BACKEND

    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
    def load_lora_weights(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
        `self.text_encoder`.
@@ -1120,6 +1127,9 @@ class LoraLoaderMixin:
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
            kwargs (`dict`, *optional*):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
        """
        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
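For reference, a minimal usage sketch of the new `adapter_name` argument on `load_lora_weights`; the LoRA paths and adapter names below are hypothetical placeholders, not part of this commit:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Each call registers its LoRA under an explicit name instead of `default_{i}`,
# so several adapters can coexist on the same text encoder / UNet.
pipe.load_lora_weights("path/to/pixel-art-lora", adapter_name="pixel")
pipe.load_lora_weights("path/to/watercolor-lora", adapter_name="watercolor")
```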
@@ -1143,6 +1153,7 @@ class LoraLoaderMixin:
            text_encoder=self.text_encoder,
            lora_scale=self.lora_scale,
            low_cpu_mem_usage=low_cpu_mem_usage,
            adapter_name=adapter_name,
            _pipeline=self,
        )

@@ -1500,6 +1511,7 @@ class LoraLoaderMixin:
        prefix=None,
        lora_scale=1.0,
        low_cpu_mem_usage=None,
        adapter_name=None,
        _pipeline=None,
    ):
        """
@@ -1523,6 +1535,9 @@ class LoraLoaderMixin:
                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
        """
        low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT

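The same `adapter_name` plumbing is also exposed on the lower-level classmethod. A hedged sketch based on the signature shown in the hunks above; the checkpoint path and adapter name are hypothetical:

```python
from diffusers import StableDiffusionPipeline
from diffusers.loaders import LoraLoaderMixin

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Resolve the checkpoint into a state dict plus network alphas first ...
state_dict, network_alphas = LoraLoaderMixin.lora_state_dict("path/to/lora")

# ... then inject it into the text encoder only, under a chosen adapter name.
LoraLoaderMixin.load_lora_into_text_encoder(
    state_dict,
    network_alphas=network_alphas,
    text_encoder=pipe.text_encoder,
    lora_scale=0.8,
    adapter_name="style",
)
```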
@@ -1584,19 +1599,22 @@ class LoraLoaderMixin:
            if cls.use_peft_backend:
                from peft import LoraConfig

                lora_rank = list(rank.values())[0]
                # By definition, the scale should be alpha divided by rank.
                # https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
                alpha = lora_scale * lora_rank
                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)

                target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
                if patch_mlp:
                    target_modules += ["fc1", "fc2"]
                lora_config = LoraConfig(**lora_config_kwargs)

                # TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
                lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
                # adapter_name
                if adapter_name is None:
                    adapter_name = get_adapter_name(text_encoder)

                text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
                # inject LoRA layers and load the state dict
                text_encoder.load_adapter(
                    adapter_name=adapter_name,
                    adapter_state_dict=text_encoder_lora_state_dict,
                    peft_config=lora_config,
                )
                # scale LoRA layers with `lora_scale`
                scale_lora_layers(text_encoder, weight=lora_scale)

                is_model_cpu_offload = False
                is_sequential_cpu_offload = False
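The hunk above replaces the single rank/alpha `LoraConfig` with one built from `get_peft_kwargs`. A sketch (not the library's exact call site) of the kind of config that now gets assembled, assuming a PEFT version that supports `rank_pattern`/`alpha_pattern` (see the linked PEFT PR); the module names are illustrative:

```python
from peft import LoraConfig

# One base rank/alpha plus per-module overrides is what lets a single text
# encoder carry LoRA layers with different ranks and alphas.
lora_config = LoraConfig(
    r=4,                        # most common rank in the checkpoint
    lora_alpha=4,               # most common alpha
    rank_pattern={"text_model.encoder.layers.0.self_attn.q_proj": 8},
    alpha_pattern={"text_model.encoder.layers.0.self_attn.q_proj": 16},
    target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
```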
@@ -2178,6 +2196,81 @@ class LoraLoaderMixin:

        self.num_fused_loras -= 1

    def set_adapter_for_text_encoder(
        self,
        adapter_names: Union[List[str], str],
        text_encoder: Optional[PreTrainedModel] = None,
        text_encoder_weights: List[float] = None,
    ):
        """
        Sets the adapter layers for the text encoder.

        Args:
            adapter_names (`List[str]` or `str`):
                The names of the adapters to use.
            text_encoder (`torch.nn.Module`, *optional*):
                The text encoder module to set the adapter layers for. If `None`, it will try to get the
                `text_encoder` attribute.
            text_encoder_weights (`List[float]`, *optional*):
                The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the
                adapters.
        """
        if not self.use_peft_backend:
            raise ValueError("PEFT backend is required for this method.")

        def process_weights(adapter_names, weights):
            if weights is None:
                weights = [1.0] * len(adapter_names)
            elif isinstance(weights, float):
                weights = [weights]

            if len(adapter_names) != len(weights):
                raise ValueError(
                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
                )
            return weights

        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
        text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
        text_encoder = text_encoder or getattr(self, "text_encoder", None)
        if text_encoder is None:
            raise ValueError(
                "The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
            )
        set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)

    def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
        """
        Disables the LoRA layers for the text encoder.

        Args:
            text_encoder (`torch.nn.Module`, *optional*):
                The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
                `text_encoder` attribute.
        """
        if not self.use_peft_backend:
            raise ValueError("PEFT backend is required for this method.")

        text_encoder = text_encoder or getattr(self, "text_encoder", None)
        if text_encoder is None:
            raise ValueError("Text Encoder not found.")
        set_adapter_layers(text_encoder, enabled=False)

    def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
        """
        Enables the LoRA layers for the text encoder.

        Args:
            text_encoder (`torch.nn.Module`, *optional*):
                The text encoder module to enable the LoRA layers for. If `None`, it will try to get the
                `text_encoder` attribute.
        """
        if not self.use_peft_backend:
            raise ValueError("PEFT backend is required for this method.")
        text_encoder = text_encoder or getattr(self, "text_encoder", None)
        if text_encoder is None:
            raise ValueError("Text Encoder not found.")
        set_adapter_layers(text_encoder, enabled=True)


class FromSingleFileMixin:
    """

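A hedged usage sketch of the three new text-encoder adapter controls; `pipe` is assumed to be a pipeline with the PEFT backend enabled, and the adapter names are hypothetical ones previously registered via `load_lora_weights`:

```python
# Activate two adapters on the text encoder and weight their contributions.
pipe.set_adapter_for_text_encoder(["pixel", "watercolor"], text_encoder_weights=[0.7, 0.3])

# Temporarily run the text encoder without any LoRA contribution ...
pipe.disable_lora_for_text_encoder()

# ... and switch the LoRA layers back on again.
pipe.enable_lora_for_text_encoder()
```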
@@ -19,7 +19,7 @@ import torch.nn.functional as F
from torch import nn

from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
from ..utils import logging
from ..utils import logging, scale_lora_layers


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -27,11 +27,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
    if use_peft_backend:
        from peft.tuners.lora import LoraLayer

        for module in text_encoder.modules():
            if isinstance(module, LoraLayer):
                module.scaling[module.active_adapter] = lora_scale
        scale_lora_layers(text_encoder, weight=lora_scale)
    else:
        for _, attn_module in text_encoder_attn_modules(text_encoder):
            if isinstance(attn_module.q_proj, PatchedLoraProjection):
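With the PEFT backend, the manual `LoraLayer.scaling` update above is replaced by a single call to `scale_lora_layers`, which rescales every PEFT LoRA layer in the module tree. A minimal sketch; the import path is assumed from this hunk's file context and `pipe` is a pipeline with a LoRA-injected text encoder:

```python
from diffusers.models.lora import adjust_lora_scale_text_encoder

# Halve the LoRA contribution of the text encoder for the next forward pass.
adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale=0.5, use_peft_backend=True)
```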
@@ -84,7 +84,14 @@ from .import_utils import (
from .loading_utils import load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import recurse_remove_peft_layers
from .peft_utils import (
    get_adapter_name,
    get_peft_kwargs,
    recurse_remove_peft_layers,
    scale_lora_layers,
    set_adapter_layers,
    set_weights_and_activate_adapters,
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft

@@ -14,6 +14,8 @@
"""
PEFT utilities: Utilities related to peft library
"""
import collections

from .import_utils import is_torch_available

@@ -68,3 +70,98 @@ def recurse_remove_peft_layers(model):
        torch.cuda.empty_cache()

    return model


def scale_lora_layers(model, weight):
    """
    Adjust the weighting given to the LoRA layers of the model.

    Args:
        model (`torch.nn.Module`):
            The model to scale.
        weight (`float`):
            The weight to be given to the LoRA layers.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            module.scale_layer(weight)


def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
    rank_pattern = {}
    alpha_pattern = {}
    r = lora_alpha = list(rank_dict.values())[0]
    if len(set(rank_dict.values())) > 1:
        # get the rank occurring the most number of times
        r = collections.Counter(rank_dict.values()).most_common()[0][0]

        # for modules with a rank different from the most occurring rank, add them to `rank_pattern`
        rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
        rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}

    if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
        # get the alpha occurring the most number of times
        lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

        # for modules with an alpha different from the most occurring alpha, add them to `alpha_pattern`
        alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
        alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}

    # layer names without the Diffusers-specific LoRA suffixes
    target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})

    lora_config_kwargs = {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
    }
    return lora_config_kwargs


def get_adapter_name(model):
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            return f"default_{len(module.r)}"
    return "default_0"


def set_adapter_layers(model, enabled=True):
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            # Recent versions of PEFT need to call `enable_adapters` instead
            if hasattr(module, "enable_adapters"):
                module.enable_adapters(enabled=enabled)
            else:
                module.disable_adapters = not enabled


def set_weights_and_activate_adapters(model, adapter_names, weights):
    from peft.tuners.tuners_utils import BaseTunerLayer

    # iterate over each adapter, make it active and set the corresponding scaling weight
    for adapter_name, weight in zip(adapter_names, weights):
        for module in model.modules():
            if isinstance(module, BaseTunerLayer):
                # For backward compatibility with previous PEFT versions
                if hasattr(module, "set_adapter"):
                    module.set_adapter(adapter_name)
                else:
                    module.active_adapter = adapter_name
                module.scale_layer(weight)

    # set multiple active adapters
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            # For backward compatibility with previous PEFT versions
            if hasattr(module, "set_adapter"):
                module.set_adapter(adapter_names)
            else:
                module.active_adapter = adapter_names
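A minimal sketch of what `get_peft_kwargs` produces for a checkpoint whose layers do not all share the same rank and alpha; the key names are simplified placeholders, real checkpoints use the full module paths:

```python
from diffusers.utils import get_peft_kwargs

rank_dict = {
    "layers.0.q_proj.lora_B.weight": 4,
    "layers.1.q_proj.lora_B.weight": 4,
    "layers.2.q_proj.lora_B.weight": 8,   # one layer with a different rank
}
network_alpha_dict = {
    "layers.0.q_proj.alpha": 4.0,
    "layers.1.q_proj.alpha": 4.0,
    "layers.2.q_proj.alpha": 16.0,        # and a different alpha
}
state_dict = {k: None for k in rank_dict}  # only the key names are inspected

kwargs = get_peft_kwargs(rank_dict, network_alpha_dict, state_dict)
# -> r=4, lora_alpha=4.0, plus rank_pattern / alpha_pattern entries for the
#    outlier module ("layers.2.q_proj"), ready to be splatted into
#    `peft.LoraConfig(**kwargs)` as shown in the loaders.py hunk above.
```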