diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 4590c2452b..a34ec6acce 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -36,6 +36,7 @@ from .utils import (
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
+    delete_adapter_layers,
     deprecate,
     get_adapter_name,
     get_peft_kwargs,
@@ -752,6 +753,27 @@ class UNet2DConditionLoadersMixin:
             raise ValueError("PEFT backend is required for this method.")
         set_adapter_layers(self, enabled=True)
 
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        """
+        Delete the LoRA layers of `adapter_names` from the UNet.
+
+        Args:
+            adapter_names (`Union[List[str], str]`):
+                The names of the adapters to delete. Can be a single string or a list of strings.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        if isinstance(adapter_names, str):
+            adapter_names = [adapter_names]
+
+        for adapter_name in adapter_names:
+            delete_adapter_layers(self, adapter_name)
+
+            # Also pop the corresponding adapter from the config
+            if hasattr(self, "peft_config"):
+                self.peft_config.pop(adapter_name, None)
+
 
 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
     cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
@@ -2507,6 +2529,30 @@ class LoraLoaderMixin:
         if hasattr(self, "text_encoder_2"):
             self.enable_lora_for_text_encoder(self.text_encoder_2)
 
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        """
+        Delete the LoRA layers of `adapter_names` from the UNet and text encoder(s).
+
+        Args:
+            adapter_names (`Union[List[str], str]`):
+                The names of the adapters to delete. Can be a single string or a list of strings.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        if isinstance(adapter_names, str):
+            adapter_names = [adapter_names]
+
+        # Delete unet adapters
+        self.unet.delete_adapters(adapter_names)
+
+        for adapter_name in adapter_names:
+            # Delete text encoder adapters
+            if hasattr(self, "text_encoder"):
+                delete_adapter_layers(self.text_encoder, adapter_name)
+            if hasattr(self, "text_encoder_2"):
+                delete_adapter_layers(self.text_encoder_2, adapter_name)
+
     def get_active_adapters(self) -> List[str]:
         """
         Gets the list of the current active adapters.
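For reference, a minimal usage sketch of the new `delete_adapters` pipeline method added above (not part of the patch itself). The checkpoint id and LoRA paths are placeholders, and the PEFT backend is assumed to be installed and enabled:

```python
import torch
from diffusers import StableDiffusionPipeline

# Placeholder checkpoint; any pipeline that inherits LoraLoaderMixin works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Load two LoRA checkpoints under distinct adapter names (placeholder paths).
pipe.load_lora_weights("path/to/lora-1", adapter_name="adapter-1")
pipe.load_lora_weights("path/to/lora-2", adapter_name="adapter-2")
pipe.set_adapters(["adapter-1", "adapter-2"])

# Delete a single adapter by name ...
pipe.delete_adapters("adapter-1")

# ... or a list of them; this removes the LoRA layers from the UNet and the
# text encoder(s) and pops the matching entries from `peft_config`.
pipe.delete_adapters(["adapter-2"])
```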
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index b4d6bdab33..c1385d5847 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -89,6 +89,7 @@ from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (
     check_peft_version,
+    delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
     recurse_remove_peft_layers,
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 158435a6e8..2bcbeb3b79 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -180,6 +180,28 @@ def set_adapter_layers(model, enabled=True):
                 module.disable_adapters = not enabled
 
 
+def delete_adapter_layers(model, adapter_name):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            if hasattr(module, "delete_adapter"):
+                module.delete_adapter(adapter_name)
+            else:
+                raise ValueError(
+                    "The version of PEFT you are using is not compatible; please use a version greater than 0.6.1"
+                )
+
+    # For the transformers integration we also need to pop the adapter from the config
+    if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
+        model.peft_config.pop(adapter_name, None)
+        # In case all adapters are deleted, delete the config entirely
+        # and reset the `_hf_peft_config_loaded` flag
+        if len(model.peft_config) == 0:
+            del model.peft_config
+            model._hf_peft_config_loaded = None
+
+
 def set_weights_and_activate_adapters(model, adapter_names, weights):
     from peft.tuners.tuners_utils import BaseTunerLayer
 
diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py
index 68e986790d..1c651d53bf 100644
--- a/tests/lora/test_lora_layers_peft.py
+++ b/tests/lora/test_lora_layers_peft.py
@@ -831,6 +831,96 @@ class PeftLoraLoaderMixinTests:
                 "output with no lora and output with lora disabled should give same results",
             )
 
+    def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
+        """
+        Tests a simple inference with lora attached to the text encoder and unet, attaches
+        multiple adapters and sets/deletes them
+        """
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(self.torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
+            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+                self.assertTrue(
+                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+            pipe.set_adapters("adapter-1")
+
+            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.set_adapters("adapter-2")
+            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.set_adapters(["adapter-1", "adapter-2"])
+
+            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertFalse(
+                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+                "Adapter 1 and 2 should give different results",
+            )
+
+            self.assertFalse(
+                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+                "Adapter 1 and mixed adapters should give different results",
+            )
+
+            self.assertFalse(
+                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+                "Adapter 2 and mixed adapters should give different results",
+            )
+
+            pipe.delete_adapters("adapter-1")
+            output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertTrue(
+                np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+                "Deleting adapter 1 should give the same results as using only adapter 2",
+            )
+
+            pipe.delete_adapters("adapter-2")
+            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertTrue(
+                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
+                "output with no lora and output with all adapters deleted should give same results",
+            )
+
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+
+            pipe.set_adapters(["adapter-1", "adapter-2"])
+            pipe.delete_adapters(["adapter-1", "adapter-2"])
+
+            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertTrue(
+                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
+                "output with no lora and output with all adapters deleted should give same results",
+            )
+
     def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
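Outside the test harness, a small sketch of the model-level path the test exercises. The tiny UNet configuration and LoRA settings below mirror the dummy components used in `test_lora_layers_peft.py` and are illustrative assumptions, not part of the patch:

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Tiny UNet so the example runs without downloading weights (illustrative config).
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_v"], init_lora_weights=False)
unet.add_adapter(lora_config, "adapter-1")
unet.add_adapter(lora_config, "adapter-2")

# Deleting an adapter strips its LoRA layers and drops the entry from `unet.peft_config`.
unet.delete_adapters("adapter-1")
assert "adapter-1" not in getattr(unet, "peft_config", {})

# Deleting the remaining adapter leaves a plain UNet with no LoRA layers attached.
unet.delete_adapters("adapter-2")
```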