From 692b7a907d64f9ca375eb09cc211e632b7767693 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Fri, 14 Jul 2023 16:30:18 +0530
Subject: [PATCH] [Feat] add: utility for unloading lora. (#4034)

* add: test for testing unloading lora.

* add :reason to skipif.

* initial implementation of lora unload().

* apply styling.

* add: doc.

* change checkpoints.

* reinit generator

* finalize slow test.

* add fast test for unloading lora.
---
 docs/source/en/training/lora.mdx |  4 ++
 src/diffusers/loaders.py         | 34 +++++++++++++
 tests/models/test_lora_layers.py | 85 +++++++++++++++++++++++++++++---
 3 files changed, 116 insertions(+), 7 deletions(-)

diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx
index dfb31c7ef8..f7cfa5a8ea 100644
--- a/docs/source/en/training/lora.mdx
+++ b/docs/source/en/training/lora.mdx
@@ -280,6 +280,10 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is
 **Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as
 [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs, refer to the respective docstrings.
 
+## Unloading LoRA parameters
+
+You can call [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`] on a pipeline to unload the LoRA parameters.
+
 ## Supporting A1111 themed LoRA checkpoints from Diffusers
 
 To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 561ae74073..81dcc3618c 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -25,6 +25,8 @@ from torch import nn
 from .models.attention_processor import (
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
+    AttnProcessor,
+    AttnProcessor2_0,
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
@@ -1270,6 +1272,38 @@ class LoraLoaderMixin:
         new_state_dict = {**unet_state_dict, **te_state_dict}
         return new_state_dict, network_alpha
 
+    def unload_lora_weights(self):
+        """
+        Unloads the LoRA parameters.
+
+        Examples:
+
+        ```python
+        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+        >>> pipeline.unload_lora_weights()
+        >>> ...
+        ```
+        """
+        is_unet_lora = all(
+            isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor, LoRAAttnAddedKVProcessor))
+            for _, processor in self.unet.attn_processors.items()
+        )
+        # Handle attention processors that are a mix of regular attention and AddedKV
+        # attention.
+        if is_unet_lora:
+            is_attn_procs_mixed = all(
+                isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor))
+                for _, processor in self.unet.attn_processors.items()
+            )
+            if not is_attn_procs_mixed:
+                unet_attn_proc_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                self.unet.set_attn_processor(unet_attn_proc_cls())
+            else:
+                self.unet.set_default_attn_processor()
+
+        # Safe to call the following regardless of LoRA.
+        self._remove_text_encoder_monkey_patch()
+
 
 class FromSingleFileMixin:
     """
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index 3190a12389..1396561367 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -83,9 +83,9 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module):
     return text_encoder_lora_layers
 
 
-def set_lora_weights(text_lora_attn_parameters, randn_weight=False):
+def set_lora_weights(lora_attn_parameters, randn_weight=False):
     with torch.no_grad():
-        for parameter in text_lora_attn_parameters:
+        for parameter in lora_attn_parameters:
             if randn_weight:
                 parameter[:] = torch.randn_like(parameter)
             else:
@@ -155,7 +155,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
         }
         return pipeline_components, lora_components
 
-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self, with_generator=True):
         batch_size = 1
         sequence_length = 10
         num_channels = 4
@@ -167,16 +167,16 @@ class LoraLoaderMixinTests(unittest.TestCase):
 
         pipeline_inputs = {
             "prompt": "A painting of a squirrel eating a burger",
-            "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": "numpy",
+            "output_type": "np",
         }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})
 
         return noise, input_ids, pipeline_inputs
 
-    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
-
+    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
     def get_dummy_tokens(self):
         max_seq_length = 77
 
@@ -399,6 +399,45 @@ class LoraLoaderMixinTests(unittest.TestCase):
             )
             self.assertIsInstance(module.processor, attn_proc_class)
 
+    def test_unload_lora(self):
+        pipeline_components, lora_components = self.get_dummy_components()
+        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
+        sd_pipe = StableDiffusionPipeline(**pipeline_components)
+
+        original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        orig_image_slice = original_images[0, -3:, -3:, -1]
+
+        # Emulate training.
+        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
+        set_lora_weights(lora_components["text_encoder_lora_layers"].parameters(), randn_weight=True)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            LoraLoaderMixin.save_lora_weights(
+                save_directory=tmpdirname,
+                unet_lora_layers=lora_components["unet_lora_layers"],
+                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
+            )
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            sd_pipe.load_lora_weights(tmpdirname)
+
+        lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        lora_image_slice = lora_images[0, -3:, -3:, -1]
+
+        # Unload LoRA parameters.
+        sd_pipe.unload_lora_weights()
+        original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
+        orig_image_slice_two = original_images_two[0, -3:, -3:, -1]
+
+        assert not np.allclose(
+            orig_image_slice, lora_image_slice
+        ), "LoRA parameters should lead to a different image slice."
+        assert not np.allclose(
+            orig_image_slice_two, lora_image_slice
+        ), "LoRA parameters should lead to a different image slice."
+        assert np.allclose(
+            orig_image_slice, orig_image_slice_two, atol=1e-3
+        ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
+
     @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
     def test_lora_unet_attn_processors_with_xformers(self):
         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -537,3 +576,35 @@ class LoraIntegrationTests(unittest.TestCase):
         expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
 
         self.assertTrue(np.allclose(images, expected, atol=1e-4))
+
+    def test_unload_lora(self):
+        generator = torch.manual_seed(0)
+        prompt = "masterpiece, best quality, mountain"
+        num_inference_steps = 2
+
+        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to(
+            torch_device
+        )
+        initial_images = pipe(
+            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
+        ).images
+        initial_images = initial_images[0, -3:, -3:, -1].flatten()
+
+        lora_model_id = "hf-internal-testing/civitai-colored-icons-lora"
+        lora_filename = "Colored_Icons_by_vizsumit.safetensors"
+
+        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
+        lora_images = pipe(
+            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
+        ).images
+        lora_images = lora_images[0, -3:, -3:, -1].flatten()
+
+        pipe.unload_lora_weights()
+        generator = torch.manual_seed(0)
+        unloaded_lora_images = pipe(
+            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
+        ).images
+        unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
+
+        self.assertFalse(np.allclose(initial_images, lora_images))
+        self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
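
For context, the sketch below shows how the new `unload_lora_weights()` API is exercised end to end, assembled from the documentation snippet and the slow test in this patch. The LoRA repository and file name are the ones used in `LoraIntegrationTests.test_unload_lora`; the `"cuda"` device and the 25-step count are illustrative assumptions, not values taken from the patch.

```python
import torch
from diffusers import StableDiffusionPipeline

# Load a base pipeline (same checkpoint as the slow test).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", safety_checker=None
).to("cuda")  # illustrative device choice

prompt = "masterpiece, best quality, mountain"

# Load A1111-formatted LoRA weights and generate with them.
pipe.load_lora_weights(
    "hf-internal-testing/civitai-colored-icons-lora",
    weight_name="Colored_Icons_by_vizsumit.safetensors",
)
lora_image = pipe(prompt, generator=torch.manual_seed(0), num_inference_steps=25).images[0]

# Unload the LoRA parameters: the UNet attention processors are reset and the
# text encoder monkey patch is removed, so generation matches the base model again.
pipe.unload_lora_weights()
base_image = pipe(prompt, generator=torch.manual_seed(0), num_inference_steps=25).images[0]
```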