
Add function to remove monkey-patch for text encoder LoRA (#3649)

* merge undoable-monkeypatch

* remove TEXT_ENCODER_TARGET_MODULES, refactoring

* move create_lora_weight_file
Takuma Mori authored on 2023-06-06 17:36:13 +09:00 (committed by GitHub)
parent a8b0f42c38
commit b45204ea5a
4 changed files with 93 additions and 38 deletions
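
For orientation, the round trip this commit enables looks roughly like the sketch below. This is not code from the PR: the way the LoRA attention processors are built here is an illustrative stand-in for the test helpers, and the model id is only an example.

```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import LoRAAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
ids = pipe.tokenizer("a photo of a cat", return_tensors="pt").input_ids

embeds_before = pipe.text_encoder(ids)[0]

# Illustrative construction: one LoRAAttnProcessor per CLIP self-attention block,
# keyed by the ".self_attn" module name that _modify_text_encoder expects.
text_attn_procs = {
    name: LoRAAttnProcessor(hidden_size=module.out_proj.out_features, cross_attention_dim=None)
    for name, module in pipe.text_encoder.named_modules()
    if name.endswith(".self_attn")
}

pipe._modify_text_encoder(text_attn_procs)   # monkey-patch the q/k/v/out_proj forwards
# ... run inference with the LoRA-augmented text encoder ...

pipe._remove_text_encoder_monkey_patch()     # new in this commit: restore the original forwards

embeds_after = pipe.text_encoder(ids)[0]
assert torch.allclose(embeds_before, embeds_after)  # outputs match the pre-patch outputs
```

Because the LoRA up-projections initialize to zero, the patched pass is a no-op until real LoRA weights are set; the tests further below assign random up weights so the difference becomes observable.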


@@ -34,7 +34,7 @@ from .models.attention_processor import (
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
-    TEXT_ENCODER_TARGET_MODULES,
+    TEXT_ENCODER_ATTN_MODULE,
     _get_model_file,
     deprecate,
     is_safetensors_available,
@@ -955,6 +955,19 @@ class LoraLoaderMixin:
             return self._text_encoder_lora_attn_procs
         return
 
+    def _remove_text_encoder_monkey_patch(self):
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    if hasattr(module, "old_forward"):
+                        # restore original `forward` to remove monkey-patch
+                        module.forward = module.old_forward
+                        delattr(module, "old_forward")
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -963,37 +976,41 @@
             attn_processors: Dict[str, `LoRAAttnProcessor`]:
                 A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
         """
-        # Loop over the original attention modules.
-        for name, _ in self.text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
-                # Retrieve the module and its corresponding LoRA processor.
-                module = self.text_encoder.get_submodule(name)
-                # Construct a new function that performs the LoRA merging. We will monkey patch
-                # this forward pass.
-                attn_processor_name = ".".join(name.split(".")[:-1])
-                lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
-                old_forward = module.forward
+        # First, remove any monkey-patch that might have been applied before
+        self._remove_text_encoder_monkey_patch()
 
-                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                def make_new_forward(old_forward, lora_layer):
-                    def new_forward(x):
-                        return old_forward(x) + lora_layer(x)
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
 
-                    return new_forward
+                    # save old_forward to module that can be used to remove monkey-patch
+                    old_forward = module.old_forward = module.forward
 
-                # Monkey-patch.
-                module.forward = make_new_forward(old_forward, lora_layer)
+                    # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+                    def make_new_forward(old_forward, lora_layer):
+                        def new_forward(x):
+                            return old_forward(x) + lora_layer(x)
 
-    def _get_lora_layer_attribute(self, name: str) -> str:
-        if "q_proj" in name:
-            return "to_q_lora"
-        elif "v_proj" in name:
-            return "to_v_lora"
-        elif "k_proj" in name:
-            return "to_k_lora"
-        else:
-            return "to_out_lora"
+                        return new_forward
+
+                    # Monkey-patch.
+                    module.forward = make_new_forward(old_forward, lora_layer)
+
+    @property
+    def _lora_attn_processor_attr_to_text_encoder_attr(self):
+        return {
+            "to_q_lora": "q_proj",
+            "to_k_lora": "k_proj",
+            "to_v_lora": "v_proj",
+            "to_out_lora": "out_proj",
+        }
 
     def _load_text_encoder_attn_procs(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
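
A note on the pattern above, as a self-contained sketch with toy nn.Linear modules (nothing diffusers-specific): `make_new_forward` gives each patched projection its own scope, so the closure captures that iteration's `old_forward` and `lora_layer` rather than the loop's last values, and stashing the original forward on the module as `old_forward` is what makes `_remove_text_encoder_monkey_patch` possible.

```python
import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
deltas = [nn.Linear(4, 4, bias=False) for _ in range(3)]  # stand-ins for LoRA layers

def make_new_forward(old_forward, delta):
    # A fresh scope per call pins this iteration's old_forward and delta for the closure.
    def new_forward(x):
        return old_forward(x) + delta(x)
    return new_forward

for layer, delta in zip(layers, deltas):
    # Stash the original forward on the module so the patch can be undone later.
    layer.old_forward = layer.forward
    layer.forward = make_new_forward(layer.old_forward, delta)

x = torch.randn(2, 4)
patched = [layer(x) for layer in layers]

# Undo, mirroring _remove_text_encoder_monkey_patch: restore forward, drop the marker attribute.
for layer in layers:
    if hasattr(layer, "old_forward"):
        layer.forward = layer.old_forward
        delattr(layer, "old_forward")

restored = [layer(x) for layer in layers]
for out_patched, out_restored, delta in zip(patched, restored, deltas):
    assert torch.allclose(out_patched, out_restored + delta(x))
```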


@@ -31,7 +31,6 @@ from .constants import (
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     TEXT_ENCODER_ATTN_MODULE,
-    TEXT_ENCODER_TARGET_MODULES,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate


@@ -30,5 +30,4 @@ DIFFUSERS_CACHE = default_cache_path
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
-TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
 TEXT_ENCODER_ATTN_MODULE = ".self_attn"
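
To see what the surviving constant matches, here is a small inspection sketch using a tiny CLIP text model from transformers; the config values are arbitrary and chosen only to keep the model small.

```python
from transformers import CLIPTextConfig, CLIPTextModel

# Tiny, arbitrary config: just enough to inspect module names.
config = CLIPTextConfig(
    hidden_size=32,
    intermediate_size=37,
    num_attention_heads=4,
    num_hidden_layers=2,
    vocab_size=1000,
)
text_encoder = CLIPTextModel(config)

TEXT_ENCODER_ATTN_MODULE = ".self_attn"

attn_modules = {
    name: module
    for name, module in text_encoder.named_modules()
    if name.endswith(TEXT_ENCODER_ATTN_MODULE)
}
# One CLIPAttention per transformer layer, e.g.
# "text_model.encoder.layers.0.self_attn", "text_model.encoder.layers.1.self_attn"
print(sorted(attn_modules))

# Each attention block still exposes the projections the old
# TEXT_ENCODER_TARGET_MODULES list matched by substring.
for attn in attn_modules.values():
    for proj in ("q_proj", "k_proj", "v_proj", "out_proj"):
        attn.get_submodule(proj)  # raises AttributeError if the projection is missing
```

Matching whole `.self_attn` blocks instead of individual projection-name substrings is what lets the loader pair each projection with its LoRA layer through `_lora_attn_processor_attr_to_text_encoder_attr`.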


@@ -163,6 +163,15 @@ class LoraLoaderMixinTests(unittest.TestCase):
 
         return noise, input_ids, pipeline_inputs
 
+    def create_lora_weight_file(self, tmpdirname):
+        _, lora_components = self.get_dummy_components()
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=tmpdirname,
+            unet_lora_layers=lora_components["unet_lora_layers"],
+            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+        )
+        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+
     def test_lora_save_load(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionPipeline(**pipeline_components)
@@ -299,14 +308,45 @@
             outputs_without_lora, outputs_with_lora
         ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"
 
-    def create_lora_weight_file(self, tmpdirname):
-        _, lora_components = self.get_dummy_components()
-        LoraLoaderMixin.save_lora_weights(
-            save_directory=tmpdirname,
-            unet_lora_layers=lora_components["unet_lora_layers"],
-            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
-        )
-        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+    def test_text_encoder_lora_remove_monkey_patch(self):
+        pipeline_components, _ = self.get_dummy_components()
+        pipe = StableDiffusionPipeline(**pipeline_components)
+
+        dummy_tokens = self.get_dummy_tokens()
+
+        # inference without lora
+        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora.shape == (1, 77, 32)
+
+        # create lora_attn_procs with randn up.weights
+        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+        set_lora_up_weights(text_attn_procs, randn_weight=True)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_attn_procs)
+
+        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+        del text_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 32)
+
+        assert not torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora outputs should be different to without lora outputs"
+
+        # remove monkey patch
+        pipe._remove_text_encoder_monkey_patch()
+
+        # inference with removed lora
+        outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora_removed.shape == (1, 77, 32)
+
+        assert torch.allclose(
+            outputs_without_lora, outputs_without_lora_removed
+        ), "remove lora monkey patch should restore the original outputs"
 
     def test_lora_unet_attn_processors(self):
         with tempfile.TemporaryDirectory() as tmpdirname: