Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Add function to remove monkey-patch for text encoder LoRA (#3649)
* merge undoable-monkeypatch
* remove TEXT_ENCODER_TARGET_MODULES, refactoring
* move create_lora_weight_file
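For context, a minimal usage sketch of what this commit enables (not part of the diff; the model id and LoRA path below are illustrative placeholders): once text-encoder LoRA layers have been monkey-patched in, for example via `load_lora_weights`, the new `_remove_text_encoder_monkey_patch` restores the original `forward` methods.

    # Illustrative sketch only; model id and LoRA path are placeholders.
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe.load_lora_weights("path/to/lora")  # monkey-patches the text-encoder attention projections

    # ... run inference with the LoRA applied ...

    # Restore the original text-encoder forward passes.
    pipe._remove_text_encoder_monkey_patch()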
@@ -34,7 +34,7 @@ from .models.attention_processor import (
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
-    TEXT_ENCODER_TARGET_MODULES,
+    TEXT_ENCODER_ATTN_MODULE,
     _get_model_file,
     deprecate,
     is_safetensors_available,
@@ -955,6 +955,19 @@ class LoraLoaderMixin:
             return self._text_encoder_lora_attn_procs
         return
 
+    def _remove_text_encoder_monkey_patch(self):
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    if hasattr(module, "old_forward"):
+                        # restore original `forward` to remove monkey-patch
+                        module.forward = module.old_forward
+                        delattr(module, "old_forward")
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
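The removal logic relies on the patch step having stashed the original bound method as `module.old_forward`. A self-contained sketch of that save/patch/restore round trip on a plain `torch.nn.Linear` (illustrative only, not diffusers code):

    import torch

    module = torch.nn.Linear(4, 4)
    lora_layer = torch.nn.Linear(4, 4, bias=False)
    torch.nn.init.zeros_(lora_layer.weight)  # start from a no-op LoRA branch

    # patch: keep a handle to the original forward on the module itself
    module.old_forward = module.forward
    module.forward = lambda x, old=module.old_forward, lora=lora_layer: old(x) + lora(x)

    # un-patch: the same check-and-restore that _remove_text_encoder_monkey_patch performs per projection
    if hasattr(module, "old_forward"):
        module.forward = module.old_forward
        delattr(module, "old_forward")

    x = torch.randn(1, 4)
    assert torch.allclose(module.forward(x), module(x))  # original behavior restored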
@@ -963,37 +976,41 @@ class LoraLoaderMixin:
             attn_processors: Dict[str, `LoRAAttnProcessor`]:
                 A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
         """
-        # Loop over the original attention modules.
-        for name, _ in self.text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
-                # Retrieve the module and its corresponding LoRA processor.
-                module = self.text_encoder.get_submodule(name)
-                # Construct a new function that performs the LoRA merging. We will monkey patch
-                # this forward pass.
-                attn_processor_name = ".".join(name.split(".")[:-1])
-                lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
-                old_forward = module.forward
-
-                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                def make_new_forward(old_forward, lora_layer):
-                    def new_forward(x):
-                        return old_forward(x) + lora_layer(x)
-
-                    return new_forward
-
-                # Monkey-patch.
-                module.forward = make_new_forward(old_forward, lora_layer)
-
-    def _get_lora_layer_attribute(self, name: str) -> str:
-        if "q_proj" in name:
-            return "to_q_lora"
-        elif "v_proj" in name:
-            return "to_v_lora"
-        elif "k_proj" in name:
-            return "to_k_lora"
-        else:
-            return "to_out_lora"
+
+        # First, remove any monkey-patch that might have been applied before
+        self._remove_text_encoder_monkey_patch()
+
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
+
+                    # save old_forward to module that can be used to remove monkey-patch
+                    old_forward = module.old_forward = module.forward
+
+                    # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+                    def make_new_forward(old_forward, lora_layer):
+                        def new_forward(x):
+                            return old_forward(x) + lora_layer(x)
+
+                        return new_forward
+
+                    # Monkey-patch.
+                    module.forward = make_new_forward(old_forward, lora_layer)
+
+    @property
+    def _lora_attn_processor_attr_to_text_encoder_attr(self):
+        return {
+            "to_q_lora": "q_proj",
+            "to_k_lora": "k_proj",
+            "to_v_lora": "v_proj",
+            "to_out_lora": "out_proj",
+        }
 
     def _load_text_encoder_attn_procs(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
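The `make_new_forward` indirection (and the comment linking to #3490) exists because Python closures bind loop variables late: without the extra function scope, every patched projection would end up calling whatever `old_forward` and `lora_layer` the loop finished on. A standalone illustration of the difference (not diffusers code):

    # Late binding vs. the factory pattern used in _modify_text_encoder.
    funcs_late, funcs_scoped = [], []

    def make_adder(i):
        def add(x):
            return x + i
        return add

    for i in range(3):
        funcs_late.append(lambda x: x + i)   # all three close over the same `i`
        funcs_scoped.append(make_adder(i))   # each call creates a scope that locks in its own `i`

    assert [f(0) for f in funcs_late] == [2, 2, 2]     # only the last loop value survives
    assert [f(0) for f in funcs_scoped] == [0, 1, 2]   # the factory preserves each value, like make_new_forward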
@@ -31,7 +31,6 @@ from .constants import (
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     TEXT_ENCODER_ATTN_MODULE,
-    TEXT_ENCODER_TARGET_MODULES,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate
@@ -30,5 +30,4 @@ DIFFUSERS_CACHE = default_cache_path
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
-TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
 TEXT_ENCODER_ATTN_MODULE = ".self_attn"
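`TEXT_ENCODER_ATTN_MODULE = ".self_attn"` is a suffix match: CLIP's text encoder names its attention blocks `text_model.encoder.layers.<i>.self_attn`, each holding the `q_proj`/`k_proj`/`v_proj`/`out_proj` projections the loader patches. A small sketch of the matching on a stand-in module tree (illustrative; the classes are made up, only the naming pattern mirrors CLIP):

    import torch.nn as nn

    TEXT_ENCODER_ATTN_MODULE = ".self_attn"

    class FakeAttn(nn.Module):
        def __init__(self, dim=8):
            super().__init__()
            self.q_proj = nn.Linear(dim, dim)
            self.k_proj = nn.Linear(dim, dim)
            self.v_proj = nn.Linear(dim, dim)
            self.out_proj = nn.Linear(dim, dim)

    class FakeLayer(nn.Module):
        def __init__(self):
            super().__init__()
            self.self_attn = FakeAttn()
            self.mlp = nn.Linear(8, 8)

    encoder = nn.ModuleDict({"layers": nn.ModuleList([FakeLayer() for _ in range(2)])})

    # Only the attention blocks themselves match; their q/k/v/out children do not.
    matches = [name for name, _ in encoder.named_modules() if name.endswith(TEXT_ENCODER_ATTN_MODULE)]
    assert matches == ["layers.0.self_attn", "layers.1.self_attn"]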
@@ -163,6 +163,15 @@ class LoraLoaderMixinTests(unittest.TestCase):
 
         return noise, input_ids, pipeline_inputs
 
+    def create_lora_weight_file(self, tmpdirname):
+        _, lora_components = self.get_dummy_components()
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=tmpdirname,
+            unet_lora_layers=lora_components["unet_lora_layers"],
+            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+        )
+        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+
     def test_lora_save_load(self):
         pipeline_components, lora_components = self.get_dummy_components()
         sd_pipe = StableDiffusionPipeline(**pipeline_components)
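The relocated `create_lora_weight_file` helper writes `pytorch_lora_weights.bin` so other tests can exercise the file-based loading path. A hedged sketch of how such a helper is typically consumed (hypothetical test method, not part of this diff):

    def test_lora_weight_file_round_trip(self):
        pipeline_components, _ = self.get_dummy_components()
        sd_pipe = StableDiffusionPipeline(**pipeline_components)
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.create_lora_weight_file(tmpdirname)  # writes pytorch_lora_weights.bin
            sd_pipe.load_lora_weights(tmpdirname)     # loads the unet and text-encoder LoRA back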
@@ -299,14 +308,45 @@ class LoraLoaderMixinTests(unittest.TestCase):
                 outputs_without_lora, outputs_with_lora
             ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs"
 
-    def create_lora_weight_file(self, tmpdirname):
-        _, lora_components = self.get_dummy_components()
-        LoraLoaderMixin.save_lora_weights(
-            save_directory=tmpdirname,
-            unet_lora_layers=lora_components["unet_lora_layers"],
-            text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
-        )
-        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+    def test_text_encoder_lora_remove_monkey_patch(self):
+        pipeline_components, _ = self.get_dummy_components()
+        pipe = StableDiffusionPipeline(**pipeline_components)
+
+        dummy_tokens = self.get_dummy_tokens()
+
+        # inference without lora
+        outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora.shape == (1, 77, 32)
+
+        # create lora_attn_procs with randn up.weights
+        text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder)
+        set_lora_up_weights(text_attn_procs, randn_weight=True)
+
+        # monkey patch
+        pipe._modify_text_encoder(text_attn_procs)
+
+        # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
+        del text_attn_procs
+        gc.collect()
+
+        # inference with lora
+        outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_with_lora.shape == (1, 77, 32)
+
+        assert not torch.allclose(
+            outputs_without_lora, outputs_with_lora
+        ), "lora outputs should be different to without lora outputs"
+
+        # remove monkey patch
+        pipe._remove_text_encoder_monkey_patch()
+
+        # inference with removed lora
+        outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0]
+        assert outputs_without_lora_removed.shape == (1, 77, 32)
+
+        assert torch.allclose(
+            outputs_without_lora, outputs_without_lora_removed
+        ), "remove lora monkey patch should restore the original outputs"
 
     def test_lora_unet_attn_processors(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
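The new test relies on two helpers this hunk does not show, `create_text_encoder_lora_attn_procs` and `set_lora_up_weights`. Plausibly they look something like the sketch below, which builds one `LoRAAttnProcessor` per `.self_attn` module and then overwrites the LoRA `up` weights; treat the exact signatures as assumptions rather than the shipped test code:

    import torch
    from diffusers.models.attention_processor import LoRAAttnProcessor
    from diffusers.utils import TEXT_ENCODER_ATTN_MODULE

    def create_text_encoder_lora_attn_procs(text_encoder):
        lora_attn_procs = {}
        for name, module in text_encoder.named_modules():
            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
                # one processor per CLIPAttention block, sized to its output projection
                lora_attn_procs[name] = LoRAAttnProcessor(
                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
                )
        return lora_attn_procs

    def set_lora_up_weights(text_lora_attn_procs, randn_weight=False):
        for _, attn_proc in text_lora_attn_procs.items():
            # each *_lora child is a LoRALinearLayer with `down` and `up` Linear layers
            for layer_name, layer_module in attn_proc.named_modules():
                if layer_name.endswith("_lora"):
                    weight = (
                        torch.randn_like(layer_module.up.weight)
                        if randn_weight
                        else torch.zeros_like(layer_module.up.weight)
                    )
                    layer_module.up.weight = torch.nn.Parameter(weight)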