From 8009272f48764a5ef3f90d7d400337f0b2e84f1d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 13 Sep 2023 10:01:37 +0100 Subject: [PATCH] [Tests and Docs] Add a test on serializing pipelines with components containing fused LoRA modules (#4962) * add: test to ensure pipelines can be saved with fused lora modules. * add docs about serialization with fused lora. * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Empty-Commit * Update docs/source/en/training/lora.md Co-authored-by: Patrick von Platen --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Patrick von Platen --- docs/source/en/training/lora.md | 47 ++++++++++++++++++++++++++++-- tests/models/test_lora_layers.py | 49 ++++++++++++++++++++++++++------ 2 files changed, 85 insertions(+), 11 deletions(-) diff --git a/docs/source/en/training/lora.md b/docs/source/en/training/lora.md index 80b4c58b8a..dd7013c059 100644 --- a/docs/source/en/training/lora.md +++ b/docs/source/en/training/lora.md @@ -34,7 +34,7 @@ the attention layers of a language model is sufficient to obtain good downstream [cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. 🧨 Diffusers now supports finetuning with LoRA for [text-to-image generation](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) and [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora). This guide will show you how to do both. 
-If you'd like to store or share your model with the community, login to your Hugging Face account (create [one](hf.co/join) if you don't have one already): +If you'd like to store or share your model with the community, login to your Hugging Face account (create [one](https://hf.co/join) if you don't have one already): ```bash huggingface-cli login @@ -321,7 +321,7 @@ pipe.fuse_lora() generator = torch.manual_seed(0) images_fusion = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + "masterpiece, best quality, mountain", generator=generator, num_inference_steps=2 ).images # To work with a different `lora_scale`, first reverse the effects of `fuse_lora()`. @@ -333,7 +333,48 @@ pipe.fuse_lora(lora_scale=0.5) generator = torch.manual_seed(0) images_fusion = pipe( - "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2 + "masterpiece, best quality, mountain", generator=generator, num_inference_steps=2 +).images +``` + +## Serializing pipelines with fused LoRA parameters + +Let's say you want to save the pipeline above, which has its UNet fused with the LoRA parameters. You can do so by calling the `save_pretrained()` method on `pipe`. + +After loading the LoRA parameters into a pipeline, if you want to serialize the pipeline such that the affected model components are already fused with the LoRA parameters, you should: + +* call `fuse_lora()` on the pipeline with the desired `lora_scale`, given you've already loaded the LoRA parameters into it. +* call `save_pretrained()` on the pipeline. 
+ +Here is a complete example: + +```python +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda") +lora_model_id = "hf-internal-testing/sdxl-1.0-lora" +lora_filename = "sd_xl_offset_example-lora_1.0.safetensors" +pipe.load_lora_weights(lora_model_id, weight_name=lora_filename) + +# First, fuse the LoRA parameters. +pipe.fuse_lora() + +# Then save. +pipe.save_pretrained("my-pipeline-with-fused-lora") +``` + +Now, you can load the pipeline and directly perform inference without having to load the LoRA parameters again: + +```python +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained("my-pipeline-with-fused-lora", torch_dtype=torch.float16).to("cuda") + +generator = torch.manual_seed(0) +images_fusion = pipe( + "masterpiece, best quality, mountain", generator=generator, num_inference_steps=2 ).images ``` diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 1d846b6cdb..9affb37aa5 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -965,15 +965,11 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase): pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionXLPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) - # sd_pipe.unet.set_default_attn_processor() sd_pipe.set_progress_bar_config(disable=None) _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) - images = sd_pipe( - **pipeline_inputs, - generator=torch.manual_seed(0), - ).images + images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images images_slice = images[0, -3:, -3:, -1] # Emulate training. 
@@ -993,9 +989,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase): sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) lora_images_scale_0_5 = sd_pipe( - **pipeline_inputs, - generator=torch.manual_seed(0), - cross_attention_kwargs={"scale": 0.5}, + **pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} ).images lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1] @@ -1017,6 +1011,45 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase): images_slice, lora_image_slice_scale_0_5, atol=1e-03 ), "0.5 scale and no scale shouldn't match" + def test_save_load_fused_lora_modules(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionXLPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False) + + # Emulate training. + set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1) + set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1) + set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1) + + with tempfile.TemporaryDirectory() as tmpdirname: + StableDiffusionXLPipeline.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], + text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], + safe_serialization=True, + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) + sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) + + sd_pipe.fuse_lora() + lora_images_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images + lora_image_slice_fusion = 
lora_images_fusion[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + sd_pipe.save_pretrained(tmpdirname) + sd_pipe_loaded = StableDiffusionXLPipeline.from_pretrained(tmpdirname) + + loaded_lora_images = sd_pipe_loaded(**pipeline_inputs, generator=torch.manual_seed(0)).images + loaded_lora_image_slice = loaded_lora_images[0, -3:, -3:, -1] + + assert np.allclose( + lora_image_slice_fusion, loaded_lora_image_slice, atol=1e-03 + ), "The pipeline was serialized with LoRA parameters fused inside of the respected modules. The loaded pipeline should yield proper outputs, henceforth." + @slow @require_torch_gpu