1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Implements Blockwise lora (#7352)

* Initial commit

* Implemented block lora

- implemented block lora
- updated docs
- added tests

* Finishing up

* Reverted unrelated changes made by make style

* Fixed typo

* Fixed bug + Made text_encoder_2 scalable

* Integrated some review feedback

* Incorporated review feedback

* Fix tests

* Made every module configurable

* Adapter to new lora test structure

* Final cleanup

* Some more final fixes

- Included examples in `using_peft_for_inference.md`
- Added hint that only attns are scaled
- Removed NoneTypes
- Added test to check mismatching lens of adapter names / weights raise error

* Update using_peft_for_inference.md

* Update using_peft_for_inference.md

* Make style, quality, fix-copies

* Updated tutorial;Warning if scale/adapter mismatch

* floats are forwarded as-is; changed tutorial scale

* make style, quality, fix-copies

* Fixed typo in tutorial

* Moved some warnings into `lora_loader_utils.py`

* Moved scale/lora mismatch warnings back

* Integrated final review suggestions

* Empty commit to trigger CI

* Reverted emoty commit to trigger CI

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
UmerHA
2024-03-29 16:45:57 +01:00
committed by GitHub
parent 4d39b7483d
commit 0302446819
7 changed files with 553 additions and 21 deletions

View File

@@ -133,6 +133,62 @@ image
![no-lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_20_1.png)
### Customize adapters strength
For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].
For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
```python
pipe.enable_lora() # enable lora again, after we disabled it above
prompt = "toy_face of a hacker with a hoodie, pixel art"
adapter_weight_scales = { "unet": { "down": 1, "mid": 0, "up": 0} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```
![block-lora-text-and-down](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_down.png)
Let's see how turning off the `down` part and turning on the `mid` and `up` part respectively changes the image.
```python
adapter_weight_scales = { "unet": { "down": 0, "mid": 1, "up": 0} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```
![block-lora-text-and-mid](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mid.png)
```python
adapter_weight_scales = { "unet": { "down": 0, "mid": 0, "up": 1} }
pipe.set_adapters("pixel", adapter_weight_scales)
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```
![block-lora-text-and-up](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_up.png)
Looks cool!
This is a really powerful feature. You can use it to control the adapter strengths down to per-transformer level. And you can even use it for multiple adapters.
```python
adapter_weight_scales_toy = 0.5
adapter_weight_scales_pixel = {
"unet": {
"down": 0.9, # all transformers in the down-part will use scale 0.9
# "mid" # because, in this example, "mid" is not given, all transformers in the mid part will use the default scale 1.0
"up": {
"block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
"block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
}
}
}
pipe.set_adapters(["toy", "pixel"], [adapter_weight_scales_toy, adapter_weight_scales_pixel])
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
image
```
![block-lora-mixed](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peft_integration/diffusers_peft_lora_inference_block_mixed.png)
## Manage active adapters
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:

View File

@@ -153,18 +153,43 @@ image
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
</div>
<Tip>
For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
</Tip>
To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
```py
pipeline.unload_lora_weights()
```
### Adjust LoRA weight scale
For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
For more granular control on the amount of LoRA weights used per layer, you can use [`~loaders.LoraLoaderMixin.set_adapters`] and pass a dictionary specifying by how much to scale the weights in each layer by.
```python
pipe = ... # create pipeline
pipe.load_lora_weights(..., adapter_name="my_adapter")
scales = {
"text_encoder": 0.5,
"text_encoder_2": 0.5, # only usable if pipe has a 2nd text encoder
"unet": {
"down": 0.9, # all transformers in the down-part will use scale 0.9
# "mid" # in this example "mid" is not given, therefore all transformers in the mid part will use the default scale 1.0
"up": {
"block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
"block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
}
}
}
pipe.set_adapters("my_adapter", scales)
```
This also works with multiple adapters - see [this guide](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength) for how to do it.
<Tip warning={true}>
Currently, [`~loaders.LoraLoaderMixin.set_adapters`] only supports scaling attention weights. If a LoRA has other parts (e.g., resnets or down-/upsamplers), they will keep a scale of 1.0.
</Tip>
### Kohya and TheLastBen
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import os
from pathlib import Path
@@ -985,7 +986,7 @@ class LoraLoaderMixin:
self,
adapter_names: Union[List[str], str],
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
text_encoder_weights: List[float] = None,
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
):
"""
Sets the adapter layers for the text encoder.
@@ -1003,15 +1004,20 @@ class LoraLoaderMixin:
raise ValueError("PEFT backend is required for this method.")
def process_weights(adapter_names, weights):
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
weights = [weights]
# Expand weights into a list, one entry per adapter
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)
if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
)
# Set None values to default of 1.0
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
weights = [w if w is not None else 1.0 for w in weights]
return weights
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
@@ -1059,17 +1065,77 @@ class LoraLoaderMixin:
def set_adapters(
self,
adapter_names: Union[List[str], str],
adapter_weights: Optional[List[float]] = None,
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
):
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
adapter_weights = copy.deepcopy(adapter_weights)
# Expand weights into a list, one entry per adapter
if not isinstance(adapter_weights, list):
adapter_weights = [adapter_weights] * len(adapter_names)
if len(adapter_names) != len(adapter_weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
)
# Decompose weights into weights for unet, text_encoder and text_encoder_2
unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
all_adapters = {
adapter for adapters in list_adapters.values() for adapter in adapters
} # eg ["adapter1", "adapter2"]
invert_list_adapters = {
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
for adapter in all_adapters
} # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
unet_lora_weight = weights.pop("unet", None)
text_encoder_lora_weight = weights.pop("text_encoder", None)
text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)
if len(weights) > 0:
raise ValueError(
f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
)
if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
logger.warning(
"Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
)
# warn if adapter doesn't have parts specified by adapter_weights
for part_weight, part_name in zip(
[unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
["uent", "text_encoder", "text_encoder_2"],
):
if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
logger.warning(
f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
)
else:
unet_lora_weight = weights
text_encoder_lora_weight = weights
text_encoder_2_lora_weight = weights
unet_lora_weights.append(unet_lora_weight)
text_encoder_lora_weights.append(text_encoder_lora_weight)
text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
# Handle the UNET
unet.set_adapters(adapter_names, adapter_weights)
unet.set_adapters(adapter_names, unet_lora_weights)
# Handle the Text Encoder
if hasattr(self, "text_encoder"):
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
if hasattr(self, "text_encoder_2"):
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)
def disable_lora(self):
if not USE_PEFT_BACKEND:

View File

@@ -47,6 +47,7 @@ from .single_file_utils import (
infer_stable_cascade_single_file_config,
load_single_file_model_checkpoint,
)
from .unet_loader_utils import _maybe_expand_lora_scales
from .utils import AttnProcsLayers
@@ -564,7 +565,7 @@ class UNet2DConditionLoadersMixin:
def set_adapters(
self,
adapter_names: Union[List[str], str],
weights: Optional[Union[List[float], float]] = None,
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
):
"""
Set the currently active adapters for use in the UNet.
@@ -597,9 +598,9 @@ class UNet2DConditionLoadersMixin:
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
# Expand weights into a list, one entry per adapter
# examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
if not isinstance(weights, list):
weights = [weights] * len(adapter_names)
if len(adapter_names) != len(weights):
@@ -607,6 +608,13 @@ class UNet2DConditionLoadersMixin:
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
)
# Set None values to default of 1.0
# e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
weights = [w if w is not None else 1.0 for w in weights]
# e.g. [{...}, 7] -> [{expanded dict...}, 7]
weights = _maybe_expand_lora_scales(self, weights)
set_weights_and_activate_adapters(self, adapter_names, weights)
def disable_lora(self):

View File

@@ -0,0 +1,154 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Dict, List, Union
from ..utils import logging
if TYPE_CHECKING:
# import here to avoid circular imports
from ..models import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _translate_into_actual_layer_name(name):
"""Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
if name == "mid":
return "mid_block.attentions.0"
updown, block, attn = name.split(".")
updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
block = block.replace("block_", "")
attn = "attentions." + attn
return ".".join((updown, block, attn))
def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]]):
blocks_with_transformer = {
"down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
"up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
}
transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}
expanded_weight_scales = [
_maybe_expand_lora_scales_for_one_adapter(
weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict()
)
for weight_for_adapter in weight_scales
]
return expanded_weight_scales
def _maybe_expand_lora_scales_for_one_adapter(
scales: Union[float, Dict],
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
state_dict: None,
):
"""
Expands the inputs into a more granular dictionary. See the example below for more details.
Parameters:
scales (`Union[float, Dict]`):
Scales dict to expand.
blocks_with_transformer (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing which blocks have transformer layers
transformer_per_block (`Dict[str, int]`):
Dict with keys 'up' and 'down', showing how many transformer layers each block has
E.g. turns
```python
scales = {
'down': 2,
'mid': 3,
'up': {
'block_0': 4,
'block_1': [5, 6, 7]
}
}
blocks_with_transformer = {
'down': [1,2],
'up': [0,1]
}
transformer_per_block = {
'down': 2,
'up': 3
}
```
into
```python
{
'down.block_1.0': 2,
'down.block_1.1': 2,
'down.block_2.0': 2,
'down.block_2.1': 2,
'mid': 3,
'up.block_0.0': 4,
'up.block_0.1': 4,
'up.block_0.2': 4,
'up.block_1.0': 5,
'up.block_1.1': 6,
'up.block_1.2': 7,
}
```
"""
if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")
if sorted(transformer_per_block.keys()) != ["down", "up"]:
raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")
if not isinstance(scales, dict):
# don't expand if scales is a single number
return scales
scales = copy.deepcopy(scales)
if "mid" not in scales:
scales["mid"] = 1
for updown in ["up", "down"]:
if updown not in scales:
scales[updown] = 1
# eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
if not isinstance(scales[updown], dict):
scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}
# eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
if not isinstance(scales[updown][block], list):
scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
# eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
for i in blocks_with_transformer[updown]:
block = f"block_{i}"
for tf_idx, value in enumerate(scales[updown][block]):
scales[f"{updown}.{block}.{tf_idx}"] = value
del scales[updown]
for layer in scales.keys():
if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
raise ValueError(
f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
)
return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}

View File

@@ -230,16 +230,26 @@ def delete_adapter_layers(model, adapter_name):
def set_weights_and_activate_adapters(model, adapter_names, weights):
from peft.tuners.tuners_utils import BaseTunerLayer
def get_module_weight(weight_for_adapter, module_name):
if not isinstance(weight_for_adapter, dict):
# If weight_for_adapter is a single number, always return it.
return weight_for_adapter
for layer_name, weight_ in weight_for_adapter.items():
if layer_name in module_name:
return weight_
raise RuntimeError(f"No LoRA weight found for module {module_name}.")
# iterate over each adapter, make it active and set the corresponding scaling weight
for adapter_name, weight in zip(adapter_names, weights):
for module in model.modules():
for module_name, module in model.named_modules():
if isinstance(module, BaseTunerLayer):
# For backward compatbility with previous PEFT versions
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
module.set_scale(adapter_name, weight)
module.set_scale(adapter_name, get_module_weight(weight, module_name))
# set multiple active adapters
for module in model.modules():

View File

@@ -15,6 +15,7 @@
import os
import tempfile
import unittest
from itertools import product
import numpy as np
import torch
@@ -762,6 +763,218 @@ class PeftLoraLoaderMixinTests:
"output with no lora and output with lora disabled should give same results",
)
def test_simple_inference_with_text_unet_block_scale(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
one adapter and set differnt weights for different blocks (i.e. block lora)
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
pipe.set_adapters("adapter-1", weights_1)
output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
weights_2 = {"unet": {"up": 5}}
pipe.set_adapters("adapter-1", weights_2)
output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
"LoRA weights 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
"No adapter and LoRA weights 1 should give different results",
)
self.assertFalse(
np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
"No adapter and LoRA weights 2 should give different results",
)
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
def test_simple_inference_with_text_unet_multi_adapter_block_lora(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set differnt weights for different blocks (i.e. block lora)
"""
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
scales_2 = {"unet": {"down": 5, "mid": 5}}
pipe.set_adapters("adapter-1", scales_1)
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.set_adapters("adapter-2", scales_2)
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
# Fuse and unfuse should lead to the same results
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
# a mismatching number of adapter_names and adapter_weights should raise an error
with self.assertRaises(ValueError):
pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
def test_simple_inference_with_text_unet_block_scale_for_all_dict_options(self):
"""Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
def updown_options(blocks_with_tf, layers_per_block, value):
"""
Generate every possible combination for how a lora weight dict for the up/down part can be.
E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
"""
num_val = value
list_val = [value] * layers_per_block
node_opts = [None, num_val, list_val]
node_opts_foreach_block = [node_opts] * len(blocks_with_tf)
updown_opts = [num_val]
for nodes in product(*node_opts_foreach_block):
if all(n is None for n in nodes):
continue
opt = {}
for b, n in zip(blocks_with_tf, nodes):
if n is not None:
opt["block_" + str(b)] = n
updown_opts.append(opt)
return updown_opts
def all_possible_dict_opts(unet, value):
"""
Generate every possible combination for how a lora weight dict can be.
E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
"""
down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")]
up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")]
layers_per_block = unet.config.layers_per_block
text_encoder_opts = [None, value]
text_encoder_2_opts = [None, value]
mid_opts = [None, value]
down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value)
up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value)
opts = []
for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts):
if all(o is None for o in (t1, t2, d, m, u)):
continue
opt = {}
if t1 is not None:
opt["text_encoder"] = t1
if t2 is not None:
opt["text_encoder_2"] = t2
if all(o is None for o in (d, m, u)):
# no unet scaling
continue
opt["unet"] = {}
if d is not None:
opt["unet"]["down"] = d
if m is not None:
opt["unet"]["mid"] = m
if u is not None:
opt["unet"]["up"] = u
opts.append(opt)
return opts
components, text_lora_config, unet_lora_config = self.get_dummy_components(self.scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
# test if lora block scales can be set with this scale_dict
if not self.has_two_text_encoders and "text_encoder_2" in scale_dict:
del scale_dict["text_encoder_2"]
pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error
def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches