[Feat] add: utility for unloading lora. (#4034)
* add: test for testing unloading lora.
* add: reason to skipif.
* initial implementation of lora unload().
* apply styling.
* add: doc.
* change checkpoints.
* reinit generator.
* finalize slow test.
* add fast test for unloading lora.
@@ -280,6 +280,10 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is
 **Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs,
 refer to the respective docstrings.
 
+## Unloading LoRA parameters
+
+You can call [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`] on a pipeline to unload the LoRA parameters.
+
 ## Supporting A1111 themed LoRA checkpoints from Diffusers
 
 To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted
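The docs hunk above only adds a one-line pointer, so here is a hedged sketch of the full load → generate → unload → generate workflow it describes. The checkpoint ids are borrowed from the integration test at the bottom of this diff, not from the docs hunk itself:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None
).to("cuda")  # assumes a CUDA device; drop `.to(...)` for CPU

# Load LoRA parameters into the pipeline.
pipe.load_lora_weights(
    "hf-internal-testing/civitai-colored-icons-lora",
    weight_name="Colored_Icons_by_vizsumit.safetensors",
)
lora_image = pipe("masterpiece, best quality, mountain", generator=torch.manual_seed(0)).images[0]

# Unload them again; the pipeline behaves as if no LoRA had ever been loaded.
pipe.unload_lora_weights()
plain_image = pipe("masterpiece, best quality, mountain", generator=torch.manual_seed(0)).images[0]
```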
@@ -25,6 +25,8 @@ from torch import nn
 from .models.attention_processor import (
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
+    AttnProcessor,
+    AttnProcessor2_0,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
@@ -1270,6 +1272,38 @@ class LoraLoaderMixin:
         new_state_dict = {**unet_state_dict, **te_state_dict}
         return new_state_dict, network_alpha
 
+    def unload_lora_weights(self):
+        """
+        Unloads the LoRA parameters.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+        >>> pipeline.unload_lora_weights()
+        >>> ...
+        ```
+        """
+        is_unet_lora = all(
+            isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor, LoRAAttnAddedKVProcessor))
+            for _, processor in self.unet.attn_processors.items()
+        )
+        # Handle attention processors that are a mix of regular attention and AddedKV
+        # attention.
+        if is_unet_lora:
+            is_attn_procs_mixed = all(
+                isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor))
+                for _, processor in self.unet.attn_processors.items()
+            )
+            if not is_attn_procs_mixed:
+                unet_attn_proc_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                self.unet.set_attn_processor(unet_attn_proc_cls())
+            else:
+                self.unet.set_default_attn_processor()
+
+        # Safe to call the following regardless of LoRA.
+        self._remove_text_encoder_monkey_patch()
+
+
 class FromSingleFileMixin:
     """
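Seen from outside the mixin, the method above swaps processor instances: `unet.attn_processors` maps attention module names to processor objects, and after unloading none of them should be LoRA variants anymore (the `hasattr(F, "scaled_dot_product_attention")` check simply selects the PyTorch 2.x vanilla processor when SDPA is available). A minimal inspection sketch, reusing the checkpoint ids from the tests below and the LoRA processor classes as they exist at this commit:

```python
from diffusers import StableDiffusionPipeline
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
pipe.load_lora_weights(
    "hf-internal-testing/civitai-colored-icons-lora",
    weight_name="Colored_Icons_by_vizsumit.safetensors",
)

# While LoRA is loaded, the UNet's attention modules carry LoRA processors.
print({type(proc).__name__ for proc in pipe.unet.attn_processors.values()})

pipe.unload_lora_weights()

# After unloading, no LoRA processor should remain anywhere in the UNet.
assert not any(
    isinstance(proc, (LoRAAttnProcessor, LoRAAttnProcessor2_0))
    for proc in pipe.unet.attn_processors.values()
)
```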
@@ -83,9 +83,9 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module):
     return text_encoder_lora_layers
 
 
-def set_lora_weights(text_lora_attn_parameters, randn_weight=False):
+def set_lora_weights(lora_attn_parameters, randn_weight=False):
     with torch.no_grad():
-        for parameter in text_lora_attn_parameters:
+        for parameter in lora_attn_parameters:
             if randn_weight:
                 parameter[:] = torch.randn_like(parameter)
             else:
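The rename above widens the helper from text-encoder-only to any LoRA parameter set; the fast test below calls it for both the UNet and the text encoder layers. A self-contained sketch of the helper's behavior, with an arbitrary `nn.Linear` standing in for LoRA layers; the `else` branch is cut off by the hunk boundary, so its body here (zeroing the parameter) is an assumption:

```python
import torch
from torch import nn

def set_lora_weights(lora_attn_parameters, randn_weight=False):
    # In-place (re)initialization: random weights emulate a training run,
    # zeros (assumed else-branch) emulate an untouched adapter.
    with torch.no_grad():
        for parameter in lora_attn_parameters:
            if randn_weight:
                parameter[:] = torch.randn_like(parameter)
            else:
                parameter.zero_()

layer = nn.Linear(4, 4)
set_lora_weights(layer.parameters(), randn_weight=True)
```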
@@ -155,7 +155,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
         }
         return pipeline_components, lora_components
 
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self, with_generator=True):
         batch_size = 1
         sequence_length = 10
         num_channels = 4
@@ -167,16 +167,16 @@ class LoraLoaderMixinTests(unittest.TestCase):
 
         pipeline_inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": "numpy",
+            "output_type": "np",
         }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})
 
         return noise, input_ids, pipeline_inputs
 
     # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
     def get_dummy_tokens(self):
         max_seq_length = 77
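The `with_generator` switch exists so that `test_unload_lora` below can leave the generator out of the shared dummy inputs and pass a freshly seeded one per pipeline call. The reason is that a `torch.Generator` is stateful; a plain-`torch` sketch of the behavior, no diffusers required:

```python
import torch

g = torch.manual_seed(0)               # seeds and returns the default generator
first = torch.randn(3, generator=g)    # advances the generator's state

second = torch.randn(3, generator=g)   # same generator, new state -> different draw
assert not torch.equal(first, second)

g = torch.manual_seed(0)               # re-seeding replays the exact same draw
replay = torch.randn(3, generator=g)
assert torch.equal(first, replay)
```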
@@ -399,6 +399,45 @@ class LoraLoaderMixinTests(unittest.TestCase):
             )
             self.assertIsInstance(module.processor, attn_proc_class)
 
+    def test_unload_lora(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+        sd_pipe = StableDiffusionPipeline(**pipeline_components)
+
+        original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        orig_image_slice = original_images[0, -3:, -3:, -1]
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            LoraLoaderMixin.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(tmpdirname)
+
+        lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+        # Unload LoRA parameters.
+        sd_pipe.unload_lora_weights()
+        original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        orig_image_slice_two = original_images_two[0, -3:, -3:, -1]
+
+        assert not np.allclose(
+            orig_image_slice, lora_image_slice
+        ), "LoRA parameters should lead to a different image slice."
+        assert not np.allclose(
+            orig_image_slice_two, lora_image_slice
+        ), "LoRA parameters should lead to a different image slice."
+        assert np.allclose(
+            orig_image_slice, orig_image_slice_two, atol=1e-3
+        ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
+
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_lora_unet_attn_processors_with_xformers(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
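The fast test compares only a 3×3 corner slice of the last channel rather than whole images, which keeps the assertions cheap while still detecting whether LoRA changed the output. A toy illustration of the slicing and tolerance logic (the arrays here are made up, not pipeline outputs):

```python
import numpy as np

images = np.zeros((1, 64, 64, 3))           # (batch, height, width, channels)
perturbed = images.copy()
perturbed[0, -3:, -3:, -1] += 0.01          # nudge the corner patch the test inspects

image_slice = images[0, -3:, -3:, -1]       # 3x3 patch of the last channel
perturbed_slice = perturbed[0, -3:, -3:, -1]

assert not np.allclose(image_slice, perturbed_slice)          # strict: detectably different
assert np.allclose(image_slice, perturbed_slice, atol=1e-1)   # loose: "similar enough"
```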
@@ -537,3 +576,35 @@ class LoraIntegrationTests(unittest.TestCase):
         expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
 
         self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+    def test_unload_lora(self):
+        generator = torch.manual_seed(0)
+        prompt = "masterpiece, best quality, mountain"
+        num_inference_steps = 2
+
+        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to(
+            torch_device
+        )
+        initial_images = pipe(
+            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
+        ).images
+        initial_images = initial_images[0, -3:, -3:, -1].flatten()
+
+        lora_model_id = "hf-internal-testing/civitai-colored-icons-lora"
+        lora_filename = "Colored_Icons_by_vizsumit.safetensors"
+
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+        lora_images = pipe(
+            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
+        ).images
+        lora_images = lora_images[0, -3:, -3:, -1].flatten()
+
+        pipe.unload_lora_weights()
+        generator = torch.manual_seed(0)
+        unloaded_lora_images = pipe(
+            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
+        ).images
+        unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
+
+        self.assertFalse(np.allclose(initial_images, lora_images))
+        self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))