Mirror of https://github.com/huggingface/diffusers.git
[Peft / Lora] Add adapter_names in fuse_lora (#5823)
* add adapter_name in fuse
* add test
* up
* fix CI
* adapt from suggestion
* Update src/diffusers/utils/testing_utils.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* change to `require_peft_version_greater`
* change variable names in test
* Update src/diffusers/loaders/lora.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* break into 2 lines
* final comments

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
@@ -183,3 +183,26 @@ image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).ima
 # Gets the Unet back to the original state
 pipe.unfuse_lora()
 ```
+
+You can also fuse some adapters using `adapter_names` for faster generation:
+
+```py
+pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+
+pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
+# Fuses the LoRAs into the Unet
+pipe.fuse_lora(adapter_names=["pixel"])
+
+prompt = "a hacker with a hoodie, pixel art"
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+
+# Gets the Unet back to the original state
+pipe.unfuse_lora()
+
+# Fuse all adapters
+pipe.fuse_lora(adapter_names=["pixel", "toy"])
+
+prompt = "toy_face of a hacker with a hoodie, pixel art"
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+```
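As a sanity check, outputs with an adapter merged should match outputs with the same adapter merely active, since fusing changes speed, not results. A minimal sketch against the pipeline from the snippet above (assumes NumPy for the comparison and reuses the last `prompt`):

```py
import numpy as np

pipe.set_adapters(["pixel"])
unfused = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0), output_type="np").images

pipe.fuse_lora(adapter_names=["pixel"])
fused = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0), output_type="np").images
pipe.unfuse_lora()

# Merged weights should reproduce the adapter-active outputs.
assert np.allclose(unfused, fused, atol=1e-3)
```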
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 from contextlib import nullcontext
 from typing import Callable, Dict, List, Optional, Union
@@ -1001,6 +1002,7 @@ class LoraLoaderMixin:
         fuse_text_encoder: bool = True,
         lora_scale: float = 1.0,
         safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
     ):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
@@ -1020,6 +1022,21 @@ class LoraLoaderMixin:
                 Controls how much to influence the outputs with the LoRA parameters.
             safe_fusing (`bool`, defaults to `False`):
                 Whether to check fused weights for NaN values before fusing and, if values are NaN, not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
         """
         if fuse_unet or fuse_text_encoder:
             self.num_fused_loras += 1
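A note on the docstring example above: as written it fuses all active adapters. With this change, the same call can target a subset; a minimal sketch continuing that example (the "pixel" name is the `adapter_name` chosen at load time):

```py
# Continuing the docstring example above: merge only the "pixel" adapter
# into the base weights, leaving other loaded adapters alone.
pipeline.fuse_lora(lora_scale=0.7, adapter_names=["pixel"])

# Restore the original weights before fusing a different subset.
pipeline.unfuse_lora()
```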
@@ -1030,24 +1047,43 @@ class LoraLoaderMixin:
 
         if fuse_unet:
             unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
 
         if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
-                # TODO(Patrick, Younes): enable "safe" fusing
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+                merge_kwargs = {"safe_merge": safe_fusing}
+
                 for module in text_encoder.modules():
                     if isinstance(module, BaseTunerLayer):
                         if lora_scale != 1.0:
                             module.scale_layer(lora_scale)
 
-                        module.merge()
+                        # For BC with previous PEFT versions, we need to check the signature
+                        # of the `merge` method to see if it supports the `adapter_names` argument.
+                        supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                        if "adapter_names" in supported_merge_kwargs:
+                            merge_kwargs["adapter_names"] = adapter_names
+                        elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                            raise ValueError(
+                                "The `adapter_names` argument is not supported with your PEFT version. "
+                                "Please upgrade to the latest version of PEFT. `pip install -U peft`"
+                            )
+
+                        module.merge(**merge_kwargs)
 
         else:
             deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
+                if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
+                    raise ValueError(
+                        "The `adapter_names` argument is not supported in your environment. Please switch to PEFT "
+                        "backend to use this argument by installing latest PEFT and transformers."
+                        " `pip install -U peft transformers`"
+                    )
+
                 for _, attn_module in text_encoder_attn_modules(text_encoder):
                     if isinstance(attn_module.q_proj, PatchedLoraProjection):
                         attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
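The backward-compatibility branch added here is plain runtime feature detection: inspect the callee's signature and forward `adapter_names` only when the installed PEFT accepts it. A standalone sketch of the same pattern, with hypothetical `merge_old`/`merge_new` functions standing in for the two PEFT generations:

```py
import inspect


# Hypothetical stand-ins for a PEFT layer's `merge` method before and
# after it learned to accept `adapter_names`.
def merge_old(safe_merge=False):
    print("merged (old signature)")


def merge_new(safe_merge=False, adapter_names=None):
    print(f"merged (new signature), adapter_names={adapter_names}")


def call_merge(merge_fn, adapter_names=None):
    merge_kwargs = {"safe_merge": True}
    # Forward `adapter_names` only when the callee's signature accepts it.
    if "adapter_names" in inspect.signature(merge_fn).parameters:
        merge_kwargs["adapter_names"] = adapter_names
    elif adapter_names is not None:
        raise ValueError("`adapter_names` requires a newer PEFT version.")
    merge_fn(**merge_kwargs)


call_merge(merge_new, adapter_names=["pixel"])  # forwards the new kwarg
call_merge(merge_old)  # still works against the old signature
```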
@@ -1062,9 +1098,9 @@ class LoraLoaderMixin:
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
+                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing, adapter_names=adapter_names)
             if hasattr(self, "text_encoder_2"):
-                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
+                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing, adapter_names=adapter_names)
 
     def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
         r"""
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 from collections import defaultdict
 from contextlib import nullcontext
+from functools import partial
 from typing import Callable, Dict, List, Optional, Union
 
 import safetensors
@@ -504,22 +506,43 @@ class UNet2DConditionLoadersMixin:
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
         self.lora_scale = lora_scale
         self._safe_fusing = safe_fusing
-        self.apply(self._fuse_lora_apply)
+        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
 
-    def _fuse_lora_apply(self, module):
+    def _fuse_lora_apply(self, module, adapter_names=None):
         if not USE_PEFT_BACKEND:
             if hasattr(module, "_fuse_lora"):
                 module._fuse_lora(self.lora_scale, self._safe_fusing)
+
+            if adapter_names is not None:
+                raise ValueError(
+                    "The `adapter_names` argument is not supported in your environment. Please switch"
+                    " to PEFT backend to use this argument by installing latest PEFT and transformers."
+                    " `pip install -U peft transformers`"
+                )
         else:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
+            merge_kwargs = {"safe_merge": self._safe_fusing}
+
             if isinstance(module, BaseTunerLayer):
                 if self.lora_scale != 1.0:
                     module.scale_layer(self.lora_scale)
-                module.merge(safe_merge=self._safe_fusing)
+
+                # For BC with previous PEFT versions, we need to check the signature
+                # of the `merge` method to see if it supports the `adapter_names` argument.
+                supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                if "adapter_names" in supported_merge_kwargs:
+                    merge_kwargs["adapter_names"] = adapter_names
+                elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                    raise ValueError(
+                        "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+                        " to the latest version of PEFT. `pip install -U peft`"
+                    )
+
+                module.merge(**merge_kwargs)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
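Note the switch from `self.apply(self._fuse_lora_apply)` to a `functools.partial`: `nn.Module.apply` hands the callback only the module itself, so `adapter_names` must be bound ahead of time. A minimal sketch of that mechanism (the `ToyLoader` module and `_mark` helper are hypothetical):

```py
from functools import partial

import torch.nn as nn


class ToyLoader(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def _mark(self, module, adapter_names=None):
        # `nn.Module.apply` calls the callback with the module as its only
        # argument, so extra state must be bound beforehand via `partial`.
        module._fused_adapters = adapter_names

    def fuse(self, adapter_names=None):
        self.apply(partial(self._mark, adapter_names=adapter_names))


model = ToyLoader()
model.fuse(adapter_names=["pixel"])
print(model.proj._fused_adapters)  # ['pixel']
```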
@@ -300,6 +300,23 @@ def require_peft_backend(test_case):
     return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
 
 
+def require_peft_version_greater(peft_version):
+    """
+    Decorator marking a test that requires a PEFT backend with a version greater than the given one; this implies
+    specific versions of PEFT and transformers.
+    """
+
+    def decorator(test_case):
+        correct_peft_version = is_peft_available() and version.parse(
+            version.parse(importlib.metadata.version("peft")).base_version
+        ) > version.parse(peft_version)
+        return unittest.skipUnless(
+            correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
+        )(test_case)
+
+    return decorator
+
+
 def deprecate_after_peft_backend(test_case):
     """
     Decorator marking a test that will be skipped after PEFT backend
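For reference, a minimal sketch of how the new decorator would be applied (the test class and method are hypothetical):

```py
import unittest

from diffusers.utils.testing_utils import require_peft_version_greater


class HypotheticalLoraTests(unittest.TestCase):
    @require_peft_version_greater("0.6.2")
    def test_fuse_with_adapter_names(self):
        # Executes only when the installed PEFT version is strictly greater
        # than 0.6.2 (i.e. `merge` accepts `adapter_names`); otherwise the
        # test is reported as skipped.
        ...
```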
@@ -50,6 +50,7 @@ from diffusers.utils.testing_utils import (
     nightly,
     numpy_cosine_similarity_distance,
     require_peft_backend,
+    require_peft_version_greater,
     require_torch_gpu,
     slow,
     torch_device,
@@ -1105,6 +1106,68 @@ class PeftLoraLoaderMixinTests:
             {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
         )
 
+    @require_peft_version_greater(peft_version="0.6.2")
+    def test_simple_inference_with_text_lora_unet_fused_multi(self):
+        """
+        Tests a simple inference with LoRA attached to the text encoder and UNet, fuses the LoRA weights into the
+        base model, and makes sure it works as expected in the multi-adapter case.
+        """
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(self.torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
+
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+
+            # Attach a second adapter
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
+            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+                self.assertTrue(
+                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+            # set them to multi-adapter inference mode
+            pipe.set_adapters(["adapter-1", "adapter-2"])
+            outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.set_adapters(["adapter-1"])
+            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.fuse_lora(adapter_names=["adapter-1"])
+
+            # Fusing should still keep the LoRA layers, so the output should remain the same
+            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertTrue(
+                np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=1e-3, rtol=1e-3),
+                "Fused lora should not change the output",
+            )
+
+            pipe.unfuse_lora()
+            pipe.fuse_lora(adapter_names=["adapter-2", "adapter-1"])
+
+            # Fusing should still keep the LoRA layers
+            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0)).images
+            self.assertTrue(
+                np.allclose(output_all_lora_fused, outputs_all_lora, atol=1e-3, rtol=1e-3),
+                "Fused lora should not change the output",
+            )
+
     @unittest.skip("This is failing for now - need to investigate")
     def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
         """