From 3aba99af8fc0f7eb0b8be65c9769755ed43209f5 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Tue, 26 Dec 2023 16:54:47 +0100
Subject: [PATCH] =?UTF-8?q?[`Peft`=20/=20`Lora`]=C2=A0Add=20`adapter=5Fnam?=
 =?UTF-8?q?es`=20in=20`fuse=5Flora`=20=20(#5823)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* add adapter_name in fuse

* add test

* up

* fix CI

* adapt from suggestion

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: Benjamin Bossan

* change to `require_peft_version_greater`

* change variable names in test

* Update src/diffusers/loaders/lora.py

Co-authored-by: Benjamin Bossan

* break into 2 lines

* final comments

---------

Co-authored-by: Sayak Paul
Co-authored-by: Benjamin Bossan
---
 .../en/tutorials/using_peft_for_inference.md  | 23 +++++++
 src/diffusers/loaders/lora.py                 | 50 ++++++++++++---
 src/diffusers/loaders/unet.py                 | 31 +++++++--
 src/diffusers/utils/testing_utils.py          | 17 +++++
 tests/lora/test_lora_layers_peft.py           | 63 +++++++++++++++++++
 5 files changed, 173 insertions(+), 11 deletions(-)

diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index 6f317a7610..35b36b0ab2 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -183,3 +183,26 @@ image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).ima
 # Gets the Unet back to the original state
 pipe.unfuse_lora()
 ```
+
+You can also fuse only a subset of the loaded adapters by passing `adapter_names`, for faster generation:
+
+```py
+pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+
+pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
+# Fuses only the "pixel" LoRA into the Unet
+pipe.fuse_lora(adapter_names=["pixel"])
+
+prompt = "a hacker with a hoodie, pixel art"
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+
+# Gets the Unet back to the original state
+pipe.unfuse_lora()
+
+# Fuse all adapters
+pipe.fuse_lora(adapter_names=["pixel", "toy"])
+
+prompt = "toy_face of a hacker with a hoodie, pixel art"
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+```
diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py
index 2ceff743da..bbd01a9950 100644
--- a/src/diffusers/loaders/lora.py
+++ b/src/diffusers/loaders/lora.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 from contextlib import nullcontext
 from typing import Callable, Dict, List, Optional, Union
@@ -1001,6 +1002,7 @@ class LoraLoaderMixin:
         fuse_text_encoder: bool = True,
         lora_scale: float = 1.0,
         safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
     ):
         r"""
         Fuses the LoRA parameters into the original parameters of the corresponding blocks.
@@ -1020,6 +1022,21 @@ class LoraLoaderMixin:
                 Controls how much to influence the outputs with the LoRA parameters.
             safe_fusing (`bool`, defaults to `False`):
                 Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
         """
         if fuse_unet or fuse_text_encoder:
             self.num_fused_loras += 1
@@ -1030,24 +1047,43 @@ class LoraLoaderMixin:
 
         if fuse_unet:
             unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
 
         if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
-                # TODO(Patrick, Younes): enable "safe" fusing
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+                merge_kwargs = {"safe_merge": safe_fusing}
+
                 for module in text_encoder.modules():
                     if isinstance(module, BaseTunerLayer):
                         if lora_scale != 1.0:
                             module.scale_layer(lora_scale)
 
-                        module.merge()
+                        # For BC with previous PEFT versions, we need to check the signature
+                        # of the `merge` method to see if it supports the `adapter_names` argument.
+                        supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                        if "adapter_names" in supported_merge_kwargs:
+                            merge_kwargs["adapter_names"] = adapter_names
+                        elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                            raise ValueError(
+                                "The `adapter_names` argument is not supported with your PEFT version. "
+                                "Please upgrade to the latest version of PEFT. `pip install -U peft`"
+                            )
+
+                        module.merge(**merge_kwargs)
 
         else:
             deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
 
-            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, **kwargs):
+                if "adapter_names" in kwargs and kwargs["adapter_names"] is not None:
+                    raise ValueError(
+                        "The `adapter_names` argument is not supported in your environment. Please switch to the "
+                        "PEFT backend to use this argument by installing the latest PEFT and transformers."
+                        " `pip install -U peft transformers`"
+                    )
+
                 for _, attn_module in text_encoder_attn_modules(text_encoder):
                     if isinstance(attn_module.q_proj, PatchedLoraProjection):
                         attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
@@ -1062,9 +1098,9 @@ class LoraLoaderMixin:
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
+                fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing, adapter_names=adapter_names)
             if hasattr(self, "text_encoder_2"):
-                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
+                fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing, adapter_names=adapter_names)
 
     def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
         r"""
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 7dec43571b..5d4c7429e4 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import os
 from collections import defaultdict
 from contextlib import nullcontext
+from functools import partial
 from typing import Callable, Dict, List, Optional, Union
 
 import safetensors
@@ -504,22 +506,43 @@ class UNet2DConditionLoadersMixin:
 
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
-    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
         self.lora_scale = lora_scale
         self._safe_fusing = safe_fusing
-        self.apply(self._fuse_lora_apply)
+        self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
 
-    def _fuse_lora_apply(self, module):
+    def _fuse_lora_apply(self, module, adapter_names=None):
         if not USE_PEFT_BACKEND:
             if hasattr(module, "_fuse_lora"):
                 module._fuse_lora(self.lora_scale, self._safe_fusing)
+
+            if adapter_names is not None:
+                raise ValueError(
+                    "The `adapter_names` argument is not supported in your environment. Please switch"
+                    " to the PEFT backend to use this argument by installing the latest PEFT and transformers."
+                    " `pip install -U peft transformers`"
+                )
         else:
             from peft.tuners.tuners_utils import BaseTunerLayer
 
+            merge_kwargs = {"safe_merge": self._safe_fusing}
+
             if isinstance(module, BaseTunerLayer):
                 if self.lora_scale != 1.0:
                     module.scale_layer(self.lora_scale)
-                module.merge(safe_merge=self._safe_fusing)
+
+                # For BC with previous PEFT versions, we need to check the signature
+                # of the `merge` method to see if it supports the `adapter_names` argument.
+                supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+                if "adapter_names" in supported_merge_kwargs:
+                    merge_kwargs["adapter_names"] = adapter_names
+                elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
+                    raise ValueError(
+                        "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+                        " to the latest version of PEFT. `pip install -U peft`"
+                    )
+
+                module.merge(**merge_kwargs)
 
     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index 606980f8a3..df1a4fc420 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -300,6 +300,23 @@ def require_peft_backend(test_case):
     return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
 
 
+def require_peft_version_greater(peft_version):
+    """
+    Decorator marking a test that requires a PEFT version greater than the given version; such tests may also
+    need specific versions of PEFT and transformers.
+ """ + + def decorator(test_case): + correct_peft_version = is_peft_available() and version.parse( + version.parse(importlib.metadata.version("peft")).base_version + ) > version.parse(peft_version) + return unittest.skipUnless( + correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}" + )(test_case) + + return decorator + + def deprecate_after_peft_backend(test_case): """ Decorator marking a test that will be skipped after PEFT backend diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index 38e55b9ed7..c139e0d6ea 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -50,6 +50,7 @@ from diffusers.utils.testing_utils import ( nightly, numpy_cosine_similarity_distance, require_peft_backend, + require_peft_version_greater, require_torch_gpu, slow, torch_device, @@ -1105,6 +1106,68 @@ class PeftLoraLoaderMixinTests: {"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]}, ) + @require_peft_version_greater(peft_version="0.6.2") + def test_simple_inference_with_text_lora_unet_fused_multi(self): + """ + Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model + and makes sure it works as expected - with unet and multi-adapter case + """ + for scheduler_cls in [DDIMScheduler, LCMScheduler]: + components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + pipe = self.pipeline_class(**components) + pipe = pipe.to(self.torch_device) + pipe.set_progress_bar_config(disable=None) + _, _, inputs = self.get_dummy_inputs(with_generator=False) + + output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue(output_no_lora.shape == (1, 64, 64, 3)) + + pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") + pipe.unet.add_adapter(unet_lora_config, "adapter-1") + + # Attach a second adapter + pipe.text_encoder.add_adapter(text_lora_config, "adapter-2") + pipe.unet.add_adapter(unet_lora_config, "adapter-2") + + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder" + ) + self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet") + + if self.has_two_text_encoders: + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1") + pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2") + self.assertTrue( + self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2" + ) + + # set them to multi-adapter inference mode + pipe.set_adapters(["adapter-1", "adapter-2"]) + ouputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0)).images + + pipe.set_adapters(["adapter-1"]) + ouputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0)).images + + pipe.fuse_lora(adapter_names=["adapter-1"]) + + # Fusing should still keep the LoRA layers so outpout should remain the same + outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0)).images + + self.assertTrue( + np.allclose(ouputs_lora_1, outputs_lora_1_fused, atol=1e-3, rtol=1e-3), + "Fused lora should not change the output", + ) + + pipe.unfuse_lora() + pipe.fuse_lora(adapter_names=["adapter-2", "adapter-1"]) + + # Fusing should still keep the LoRA layers + output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0)).images + self.assertTrue( + np.allclose(output_all_lora_fused, ouputs_all_lora, atol=1e-3, rtol=1e-3), + "Fused lora should not 
+
     @unittest.skip("This is failing for now - need to investigate")
     def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
         """
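
Note on the backward-compatibility pattern: both `lora.py` and `unet.py` above gate the new `adapter_names` keyword on the installed PEFT version by inspecting the signature of `merge` before forwarding the argument. Below is a minimal standalone sketch of that feature-detection pattern; the helper name `merge_with_adapter_names` is illustrative only (not part of the patch), and `module` is assumed to be a PEFT `BaseTunerLayer`.

```py
import inspect


def merge_with_adapter_names(module, adapter_names=None, safe_merge=False):
    """Forward `adapter_names` to `module.merge` only when its signature accepts it."""
    merge_kwargs = {"safe_merge": safe_merge}
    # For a bound method, `inspect.signature` omits `self`, so this lists the
    # keyword arguments the installed PEFT version actually supports.
    supported_merge_kwargs = inspect.signature(module.merge).parameters
    if "adapter_names" in supported_merge_kwargs:
        merge_kwargs["adapter_names"] = adapter_names
    elif adapter_names is not None:
        # On older PEFT versions, fail loudly instead of silently merging all active adapters.
        raise ValueError("`adapter_names` requires a newer PEFT version. `pip install -U peft`")
    module.merge(**merge_kwargs)
```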
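Similarly, `require_peft_version_greater` parses the installed version twice on purpose: the inner parse exposes `base_version`, which strips pre-release and dev suffixes, so a source install reporting, say, `0.7.0.dev0` is compared as `0.7.0`. A small sketch of that comparison in isolation (assumes `peft` and `packaging` are installed; the version strings are illustrative):

```py
import importlib.metadata

from packaging import version

raw = importlib.metadata.version("peft")
# `base_version` drops suffixes such as ".dev0" or "rc1", so development builds
# of a release are treated like the release itself for the threshold check.
installed = version.parse(version.parse(raw).base_version)
print(installed > version.parse("0.6.2"))
```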