mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add lora delete feature (#5738)
* add lora delete feature * added tests and changed condition * deal with corner cases * more corner cases * rename to `delete_adapter_layers` for consistency --------- Co-authored-by: younesbelkada <younesbelkada@gmail.com>
This commit is contained in:
committed by
GitHub
parent
069123f66e
commit
9c8eca702c
@@ -36,6 +36,7 @@ from .utils import (
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
@@ -752,6 +753,27 @@ class UNet2DConditionLoadersMixin:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
set_adapter_layers(self, enabled=True)
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
"""
|
||||
Deletes the LoRA layers of `adapter_name` for the unet.
|
||||
|
||||
Args:
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
The names of the adapter to delete. Can be a single string or a list of strings
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if isinstance(adapter_names, str):
|
||||
adapter_names = [adapter_names]
|
||||
|
||||
for adapter_name in adapter_names:
|
||||
delete_adapter_layers(self, adapter_name)
|
||||
|
||||
# Pop also the corresponding adapter from the config
|
||||
if hasattr(self, "peft_config"):
|
||||
self.peft_config.pop(adapter_name, None)
|
||||
|
||||
|
||||
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
@@ -2507,6 +2529,30 @@ class LoraLoaderMixin:
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
self.enable_lora_for_text_encoder(self.text_encoder_2)
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
"""
|
||||
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
|
||||
|
||||
Args:
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
The names of the adapter to delete. Can be a single string or a list of strings
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if isinstance(adapter_names, str):
|
||||
adapter_names = [adapter_names]
|
||||
|
||||
# Delete unet adapters
|
||||
self.unet.delete_adapters(adapter_names)
|
||||
|
||||
for adapter_name in adapter_names:
|
||||
# Delete text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
delete_adapter_layers(self.text_encoder, adapter_name)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
delete_adapter_layers(self.text_encoder_2, adapter_name)
|
||||
|
||||
def get_active_adapters(self) -> List[str]:
|
||||
"""
|
||||
Gets the list of the current active adapters.
|
||||
|
||||
@@ -89,6 +89,7 @@ from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .peft_utils import (
|
||||
check_peft_version,
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
recurse_remove_peft_layers,
|
||||
|
||||
@@ -180,6 +180,28 @@ def set_adapter_layers(model, enabled=True):
|
||||
module.disable_adapters = not enabled
|
||||
|
||||
|
||||
def delete_adapter_layers(model, adapter_name):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if hasattr(module, "delete_adapter"):
|
||||
module.delete_adapter(adapter_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
|
||||
)
|
||||
|
||||
# For transformers integration - we need to pop the adapter from the config
|
||||
if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
|
||||
model.peft_config.pop(adapter_name, None)
|
||||
# In case all adapters are deleted, we need to delete the config
|
||||
# and make sure to set the flag to False
|
||||
if len(model.peft_config) == 0:
|
||||
del model.peft_config
|
||||
model._hf_peft_config_loaded = None
|
||||
|
||||
|
||||
def set_weights_and_activate_adapters(model, adapter_names, weights):
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
|
||||
@@ -831,6 +831,96 @@ class PeftLoraLoaderMixinTests:
|
||||
"output with no lora and output with lora disabled should give same results",
|
||||
)
|
||||
|
||||
def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
|
||||
"""
|
||||
Tests a simple inference with lora attached to text encoder and unet, attaches
|
||||
multiple adapters and set/delete them
|
||||
"""
|
||||
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
|
||||
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(self.torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
|
||||
|
||||
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
|
||||
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
|
||||
|
||||
self.assertTrue(
|
||||
self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
)
|
||||
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
|
||||
|
||||
if self.has_two_text_encoders:
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
|
||||
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
|
||||
self.assertTrue(
|
||||
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
|
||||
)
|
||||
|
||||
pipe.set_adapters("adapter-1")
|
||||
|
||||
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
pipe.set_adapters("adapter-2")
|
||||
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
pipe.set_adapters(["adapter-1", "adapter-2"])
|
||||
|
||||
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertFalse(
|
||||
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
|
||||
"Adapter 1 and 2 should give different results",
|
||||
)
|
||||
|
||||
self.assertFalse(
|
||||
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
|
||||
"Adapter 1 and mixed adapters should give different results",
|
||||
)
|
||||
|
||||
self.assertFalse(
|
||||
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
|
||||
"Adapter 2 and mixed adapters should give different results",
|
||||
)
|
||||
|
||||
pipe.delete_adapters("adapter-1")
|
||||
output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
|
||||
"Adapter 1 and 2 should give different results",
|
||||
)
|
||||
|
||||
pipe.delete_adapters("adapter-2")
|
||||
output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
|
||||
"output with no lora and output with lora disabled should give same results",
|
||||
)
|
||||
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
|
||||
|
||||
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
|
||||
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
|
||||
|
||||
pipe.set_adapters(["adapter-1", "adapter-2"])
|
||||
pipe.delete_adapters(["adapter-1", "adapter-2"])
|
||||
|
||||
output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
|
||||
"output with no lora and output with lora disabled should give same results",
|
||||
)
|
||||
|
||||
def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
|
||||
"""
|
||||
Tests a simple inference with lora attached to text encoder and unet, attaches
|
||||
|
||||
Reference in New Issue
Block a user