Implements Blockwise lora (#7352)
* Initial commit
* Implemented block lora
  - implemented block lora
  - updated docs
  - added tests
* Finishing up
* Reverted unrelated changes made by make style
* Fixed typo
* Fixed bug + made text_encoder_2 scalable
* Integrated some review feedback
* Incorporated review feedback
* Fix tests
* Made every module configurable
* Adapted to new lora test structure
* Final cleanup
* Some more final fixes
  - Included examples in `using_peft_for_inference.md`
  - Added hint that only attns are scaled
  - Removed NoneTypes
  - Added test to check that mismatching lengths of adapter names / weights raise an error
* Update using_peft_for_inference.md
* Update using_peft_for_inference.md
* Make style, quality, fix-copies
* Updated tutorial; warning if scale/adapter mismatch
* Floats are forwarded as-is; changed tutorial scale
* make style, quality, fix-copies
* Fixed typo in tutorial
* Moved some warnings into `lora_loader_utils.py`
* Moved scale/lora mismatch warnings back
* Integrated final review suggestions
* Empty commit to trigger CI
* Reverted empty commit to trigger CI

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
@@ -133,6 +133,62 @@ image


### Customize adapters strength

For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].

For example, here's how you can turn on the adapter for the `down` part, but turn it off for the `mid` and `up` parts:

```python
pipe.enable_lora() # enable lora again, after we disabled it above
prompt = "toy_face of a hacker with a hoodie, pixel art"
adapter_weight_scales = { "unet": { "down": 1, "mid": 0, "up": 0} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```



Let's see how turning off the `down` part and turning on the `mid` and then the `up` part changes the image.

```python
adapter_weight_scales = { "unet": { "down": 0, "mid": 1, "up": 0} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```



```python
adapter_weight_scales = { "unet": { "down": 0, "mid": 0, "up": 1} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```



Looks cool!

This is a really powerful feature. You can use it to control the adapter strengths down to the per-transformer level, and you can even use it for multiple adapters.

```python
adapter_weight_scales_toy = 0.5
adapter_weight_scales_pixel = {
    "unet": {
        "down": 0.9, # all transformers in the down-part will use scale 0.9
        # "mid" # because, in this example, "mid" is not given, all transformers in the mid part will use the default scale 1.0
        "up": {
            "block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
            "block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
        }
    }
}
pipe.set_adapters(["toy", "pixel"], [adapter_weight_scales_toy, adapter_weight_scales_pixel])
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```



## Manage active adapters

You have attached multiple adapters in this tutorial, and if you're feeling a bit lost about which adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
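
For example, with the "toy" and "pixel" adapters from this tutorial still attached, a sketch of what this might return (the exact list depends on which adapters you have loaded):

```python
active_adapters = pipe.get_active_adapters()
print(active_adapters)  # e.g. ["toy", "pixel"]
```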

@@ -153,18 +153,43 @@ image

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
</div>

<Tip>

For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.

</Tip>

To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:

```py
pipeline.unload_lora_weights()
```

### Adjust LoRA weight scale

For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
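
For example, a minimal sketch of passing the scale at inference time (the prompt here is arbitrary):

```py
# run the pipeline with the LoRA at half strength
image = pipeline(
    "bears, pizza bites", cross_attention_kwargs={"scale": 0.5}
).images[0]
```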

For more granular control over the amount of LoRA weights used per layer, you can use [`~loaders.LoraLoaderMixin.set_adapters`] and pass a dictionary specifying how much to scale the weights in each layer.

```python
pipe = ... # create pipeline
pipe.load_lora_weights(..., adapter_name="my_adapter")
scales = {
    "text_encoder": 0.5,
    "text_encoder_2": 0.5, # only usable if pipe has a 2nd text encoder
    "unet": {
        "down": 0.9, # all transformers in the down-part will use scale 0.9
        # "mid" # in this example "mid" is not given, therefore all transformers in the mid part will use the default scale 1.0
        "up": {
            "block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
            "block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
        }
    }
}
pipe.set_adapters("my_adapter", scales)
```

This also works with multiple adapters - see [this guide](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength) for how to do it.
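
For instance, scaling two adapters at once might look like the following sketch (the adapter names and scale values here are placeholders):

```python
scales_a = {"unet": {"down": 0.8}}
scales_b = {"text_encoder": 0.5, "unet": {"mid": 0.7, "up": 1.0}}
pipe.set_adapters(["adapter_a", "adapter_b"], [scales_a, scales_b])
```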

<Tip warning={true}>

Currently, [`~loaders.LoraLoaderMixin.set_adapters`] only supports scaling attention weights. If a LoRA has other parts (e.g., resnets or down-/upsamplers), they will keep a scale of 1.0.

</Tip>

### Kohya and TheLastBen

Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
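
For example, a Kohya-trained checkpoint can be passed to [`~loaders.LoraLoaderMixin.load_lora_weights`] like any other LoRA file (the path and filename below are placeholders):

```py
pipeline.load_lora_weights("path/to/lora_dir", weight_name="my_kohya_lora.safetensors")
```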

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
from pathlib import Path

@@ -985,7 +986,7 @@ class LoraLoaderMixin:
        self,
        adapter_names: Union[List[str], str],
        text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
        text_encoder_weights: List[float] = None,
        text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
    ):
        """
        Sets the adapter layers for the text encoder.

@@ -1003,15 +1004,20 @@ class LoraLoaderMixin:
            raise ValueError("PEFT backend is required for this method.")

        def process_weights(adapter_names, weights):
            if weights is None:
                weights = [1.0] * len(adapter_names)
            elif isinstance(weights, float):
                weights = [weights]
            # Expand weights into a list, one entry per adapter
            # e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
            if not isinstance(weights, list):
                weights = [weights] * len(adapter_names)

            if len(adapter_names) != len(weights):
                raise ValueError(
                    f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
                )

            # Set None values to default of 1.0
            # e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
            weights = [w if w is not None else 1.0 for w in weights]

            return weights

        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

@@ -1059,17 +1065,77 @@ class LoraLoaderMixin:
    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
        adapter_weights: Optional[List[float]] = None,
        adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
    ):
        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        adapter_weights = copy.deepcopy(adapter_weights)

        # Expand weights into a list, one entry per adapter
        if not isinstance(adapter_weights, list):
            adapter_weights = [adapter_weights] * len(adapter_names)

        if len(adapter_names) != len(adapter_weights):
            raise ValueError(
                f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
            )

        # Decompose weights into weights for unet, text_encoder and text_encoder_2
        unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []

        list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
        all_adapters = {
            adapter for adapters in list_adapters.values() for adapter in adapters
        }  # eg ["adapter1", "adapter2"]
        invert_list_adapters = {
            adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
            for adapter in all_adapters
        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}

        for adapter_name, weights in zip(adapter_names, adapter_weights):
            if isinstance(weights, dict):
                unet_lora_weight = weights.pop("unet", None)
                text_encoder_lora_weight = weights.pop("text_encoder", None)
                text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)

                if len(weights) > 0:
                    raise ValueError(
                        f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
                    )

                if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
                    logger.warning(
                        "Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
                    )

                # warn if adapter doesn't have parts specified by adapter_weights
                for part_weight, part_name in zip(
                    [unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
                    ["unet", "text_encoder", "text_encoder_2"],
                ):
                    if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
                        logger.warning(
                            f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
                        )

            else:
                unet_lora_weight = weights
                text_encoder_lora_weight = weights
                text_encoder_2_lora_weight = weights

            unet_lora_weights.append(unet_lora_weight)
            text_encoder_lora_weights.append(text_encoder_lora_weight)
            text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)

        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        # Handle the UNET
        unet.set_adapters(adapter_names, adapter_weights)
        unet.set_adapters(adapter_names, unet_lora_weights)

        # Handle the Text Encoder
        if hasattr(self, "text_encoder"):
            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
        if hasattr(self, "text_encoder_2"):
            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)

    def disable_lora(self):
        if not USE_PEFT_BACKEND:

@@ -47,6 +47,7 @@ from .single_file_utils import (
    infer_stable_cascade_single_file_config,
    load_single_file_model_checkpoint,
)
from .unet_loader_utils import _maybe_expand_lora_scales
from .utils import AttnProcsLayers

@@ -564,7 +565,7 @@ class UNet2DConditionLoadersMixin:
    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
        weights: Optional[Union[List[float], float]] = None,
        weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
    ):
        """
        Set the currently active adapters for use in the UNet.

@@ -597,9 +598,9 @@ class UNet2DConditionLoadersMixin:

        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names

        if weights is None:
            weights = [1.0] * len(adapter_names)
        elif isinstance(weights, float):
        # Expand weights into a list, one entry per adapter
        # examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
        if not isinstance(weights, list):
            weights = [weights] * len(adapter_names)

        if len(adapter_names) != len(weights):
@@ -607,6 +608,13 @@ class UNet2DConditionLoadersMixin:
                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
            )

        # Set None values to default of 1.0
        # e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
        weights = [w if w is not None else 1.0 for w in weights]

        # e.g. [{...}, 7] -> [{expanded dict...}, 7]
        weights = _maybe_expand_lora_scales(self, weights)

        set_weights_and_activate_adapters(self, adapter_names, weights)

    def disable_lora(self):
src/diffusers/loaders/unet_loader_utils.py (new file, 154 lines)
@@ -0,0 +1,154 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Dict, List, Union

from ..utils import logging


if TYPE_CHECKING:
    # import here to avoid circular imports
    from ..models import UNet2DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def _translate_into_actual_layer_name(name):
    """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
    if name == "mid":
        return "mid_block.attentions.0"

    updown, block, attn = name.split(".")

    updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
    block = block.replace("block_", "")
    attn = "attentions." + attn

    return ".".join((updown, block, attn))
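
# For illustration, this helper maps names as follows:
#   "mid"            -> "mid_block.attentions.0"
#   "down.block_1.0" -> "down_blocks.1.attentions.0"
#   "up.block_0.2"   -> "up_blocks.0.attentions.2"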


def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]]):
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
    }
    transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}

    expanded_weight_scales = [
        _maybe_expand_lora_scales_for_one_adapter(
            weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict()
        )
        for weight_for_adapter in weight_scales
    ]

    return expanded_weight_scales


def _maybe_expand_lora_scales_for_one_adapter(
    scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, int],
    transformer_per_block: Dict[str, int],
    state_dict: None,
):
    """
    Expands the inputs into a more granular dictionary. See the example below for more details.

    Parameters:
        scales (`Union[float, Dict]`):
            Scales dict to expand.
        blocks_with_transformer (`Dict[str, int]`):
            Dict with keys 'up' and 'down', showing which blocks have transformer layers
        transformer_per_block (`Dict[str, int]`):
            Dict with keys 'up' and 'down', showing how many transformer layers each block has

    E.g. turns
    ```python
    scales = {
        'down': 2,
        'mid': 3,
        'up': {
            'block_0': 4,
            'block_1': [5, 6, 7]
        }
    }
    blocks_with_transformer = {
        'down': [1,2],
        'up': [0,1]
    }
    transformer_per_block = {
        'down': 2,
        'up': 3
    }
    ```
    into
    ```python
    {
        'down.block_1.0': 2,
        'down.block_1.1': 2,
        'down.block_2.0': 2,
        'down.block_2.1': 2,
        'mid': 3,
        'up.block_0.0': 4,
        'up.block_0.1': 4,
        'up.block_0.2': 4,
        'up.block_1.0': 5,
        'up.block_1.1': 6,
        'up.block_1.2': 7,
    }
    ```
    """
    if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
        raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")

    if sorted(transformer_per_block.keys()) != ["down", "up"]:
        raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")

    if not isinstance(scales, dict):
        # don't expand if scales is a single number
        return scales

    scales = copy.deepcopy(scales)

    if "mid" not in scales:
        scales["mid"] = 1

    for updown in ["up", "down"]:
        if updown not in scales:
            scales[updown] = 1

        # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}
        if not isinstance(scales[updown], dict):
            scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}

        # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            if not isinstance(scales[updown][block], list):
                scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]

        # eg {"down": {"block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            for tf_idx, value in enumerate(scales[updown][block]):
                scales[f"{updown}.{block}.{tf_idx}"] = value

        del scales[updown]

    for layer in scales.keys():
        if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
            raise ValueError(
                f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
            )

    return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}

@@ -230,16 +230,26 @@ def delete_adapter_layers(model, adapter_name):
def set_weights_and_activate_adapters(model, adapter_names, weights):
    from peft.tuners.tuners_utils import BaseTunerLayer

    def get_module_weight(weight_for_adapter, module_name):
        if not isinstance(weight_for_adapter, dict):
            # If weight_for_adapter is a single number, always return it.
            return weight_for_adapter

        for layer_name, weight_ in weight_for_adapter.items():
            if layer_name in module_name:
                return weight_
        raise RuntimeError(f"No LoRA weight found for module {module_name}.")

    # iterate over each adapter, make it active and set the corresponding scaling weight
    for adapter_name, weight in zip(adapter_names, weights):
        for module in model.modules():
        for module_name, module in model.named_modules():
            if isinstance(module, BaseTunerLayer):
                # For backward compatibility with previous PEFT versions
                if hasattr(module, "set_adapter"):
                    module.set_adapter(adapter_name)
                else:
                    module.active_adapter = adapter_name
                module.set_scale(adapter_name, weight)
                module.set_scale(adapter_name, get_module_weight(weight, module_name))

    # set multiple active adapters
    for module in model.modules():
@@ -15,6 +15,7 @@
import os
import tempfile
import unittest
from itertools import product

import numpy as np
import torch

@@ -762,6 +763,218 @@ class PeftLoraLoaderMixinTests:
                "output with no lora and output with lora disabled should give same results",
            )

    def test_simple_inference_with_text_unet_block_scale(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        one adapter and sets different weights for different blocks (i.e. block lora)
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.unet.add_adapter(unet_lora_config, "adapter-1")

            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
            pipe.set_adapters("adapter-1", weights_1)
            output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0)).images

            weights_2 = {"unet": {"up": 5}}
            pipe.set_adapters("adapter-1", weights_2)
            output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0)).images

            self.assertFalse(
                np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
                "LoRA weights 1 and 2 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 1 should give different results",
            )
            self.assertFalse(
                np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
                "No adapter and LoRA weights 2 should give different results",
            )

            pipe.disable_lora()
            output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

    def test_simple_inference_with_text_unet_multi_adapter_block_lora(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets different weights for different blocks (i.e. block lora)
        """
        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
            pipe = self.pipeline_class(**components)
            pipe = pipe.to(torch_device)
            pipe.set_progress_bar_config(disable=None)
            _, _, inputs = self.get_dummy_inputs(with_generator=False)

            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images

            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
            pipe.unet.add_adapter(unet_lora_config, "adapter-2")

            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")

            if self.has_two_text_encoders:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

            scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
            scales_2 = {"unet": {"down": 5, "mid": 5}}
            pipe.set_adapters("adapter-1", scales_1)

            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images

            pipe.set_adapters("adapter-2", scales_2)
            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images

            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])

            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images

            # The different block-wise scalings should lead to different results
            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
                "Adapter 1 and 2 should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 1 and mixed adapters should give different results",
            )

            self.assertFalse(
                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
                "Adapter 2 and mixed adapters should give different results",
            )

            pipe.disable_lora()

            output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images

            self.assertTrue(
                np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
                "output with no lora and output with lora disabled should give same results",
            )

            # a mismatching number of adapter_names and adapter_weights should raise an error
            with self.assertRaises(ValueError):
                pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])

    def test_simple_inference_with_text_unet_block_scale_for_all_dict_options(self):
        """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""

        def updown_options(blocks_with_tf, layers_per_block, value):
            """
            Generate every possible combination for how a lora weight dict for the up/down part can be.
            E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
            """
            num_val = value
            list_val = [value] * layers_per_block

            node_opts = [None, num_val, list_val]
            node_opts_foreach_block = [node_opts] * len(blocks_with_tf)

            updown_opts = [num_val]
            for nodes in product(*node_opts_foreach_block):
                if all(n is None for n in nodes):
                    continue
                opt = {}
                for b, n in zip(blocks_with_tf, nodes):
                    if n is not None:
                        opt["block_" + str(b)] = n
                updown_opts.append(opt)
            return updown_opts

        def all_possible_dict_opts(unet, value):
            """
            Generate every possible combination for how a lora weight dict can be.
            E.g. 2, {"unet": {"down": 2}}, {"unet": {"down": [2,2,2]}}, {"unet": {"mid": 2, "up": [2,2,2]}}, ...
            """

            down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")]
            up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")]

            layers_per_block = unet.config.layers_per_block

            text_encoder_opts = [None, value]
            text_encoder_2_opts = [None, value]
            mid_opts = [None, value]
            down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value)
            up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value)

            opts = []

            for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts):
                if all(o is None for o in (t1, t2, d, m, u)):
                    continue
                opt = {}
                if t1 is not None:
                    opt["text_encoder"] = t1
                if t2 is not None:
                    opt["text_encoder_2"] = t2
                if all(o is None for o in (d, m, u)):
                    # no unet scaling
                    continue
                opt["unet"] = {}
                if d is not None:
                    opt["unet"]["down"] = d
                if m is not None:
                    opt["unet"]["mid"] = m
                if u is not None:
                    opt["unet"]["up"] = u
                opts.append(opt)

            return opts

        components, text_lora_config, unet_lora_config = self.get_dummy_components(self.scheduler_cls)
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
        pipe.unet.add_adapter(unet_lora_config, "adapter-1")

        if self.has_two_text_encoders:
            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

        for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
            # test if lora block scales can be set with this scale_dict
            if not self.has_two_text_encoders and "text_encoder_2" in scale_dict:
                del scale_dict["text_encoder_2"]

            pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error

    def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches