mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Tests and Docs] Add a test on serializing pipelines with components containing fused LoRA modules (#4962)
* add: test to ensure pipelines can be saved with fused lora modules. * add docs about serialization with fused lora. * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Empty-Commit * Update docs/source/en/training/lora.md Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -34,7 +34,7 @@ the attention layers of a language model is sufficient to obtain good downstream
|
||||
|
||||
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. 🧨 Diffusers now supports finetuning with LoRA for [text-to-image generation](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) and [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora). This guide will show you how to do both.
|
||||
|
||||
If you'd like to store or share your model with the community, login to your Hugging Face account (create [one](hf.co/join) if you don't have one already):
|
||||
If you'd like to store or share your model with the community, login to your Hugging Face account (create [one](https://hf.co/join) if you don't have one already):
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
@@ -321,7 +321,7 @@ pipe.fuse_lora()
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
images_fusion = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
"masterpiece, best quality, mountain", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
|
||||
# To work with a different `lora_scale`, first reverse the effects of `fuse_lora()`.
|
||||
@@ -333,7 +333,48 @@ pipe.fuse_lora(lora_scale=0.5)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
images_fusion = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
"masterpiece, best quality, mountain", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
```
|
||||
|
||||
## Serializing pipelines with fused LoRA parameters
|
||||
|
||||
Let's say you want to load the pipeline above that has its UNet fused with the LoRA parameters. You can easily do so by simply calling the `save_pretrained()` method on `pipe`.
|
||||
|
||||
After loading the LoRA parameters into a pipeline, if you want to serialize the pipeline such that the affected model components are already fused with the LoRA parameters, you should:
|
||||
|
||||
* call `fuse_lora()` on the pipeline with the desired `lora_scale`, given you've already loaded the LoRA parameters into it.
|
||||
* call `save_pretrained()` on the pipeline.
|
||||
|
||||
Here is a complete example:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
lora_model_id = "hf-internal-testing/sdxl-1.0-lora"
|
||||
lora_filename = "sd_xl_offset_example-lora_1.0.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
|
||||
# First, fuse the LoRA parameters.
|
||||
pipe.fuse_lora()
|
||||
|
||||
# Then save.
|
||||
pipe.save_pretrained("my-pipeline-with-fused-lora")
|
||||
```
|
||||
|
||||
Now, you can load the pipeline and directly perform inference without having to load the LoRA parameters again:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("my-pipeline-with-fused-lora", torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
images_fusion = pipe(
|
||||
"masterpiece, best quality, mountain", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
```
|
||||
|
||||
|
||||
@@ -965,15 +965,11 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
|
||||
pipeline_components, lora_components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
# sd_pipe.unet.set_default_attn_processor()
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
images = sd_pipe(
|
||||
**pipeline_inputs,
|
||||
generator=torch.manual_seed(0),
|
||||
).images
|
||||
images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
|
||||
images_slice = images[0, -3:, -3:, -1]
|
||||
|
||||
# Emulate training.
|
||||
@@ -993,9 +989,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
|
||||
sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
lora_images_scale_0_5 = sd_pipe(
|
||||
**pipeline_inputs,
|
||||
generator=torch.manual_seed(0),
|
||||
cross_attention_kwargs={"scale": 0.5},
|
||||
**pipeline_inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
|
||||
).images
|
||||
lora_image_slice_scale_0_5 = lora_images_scale_0_5[0, -3:, -3:, -1]
|
||||
|
||||
@@ -1017,6 +1011,45 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
|
||||
images_slice, lora_image_slice_scale_0_5, atol=1e-03
|
||||
), "0.5 scale and no scale shouldn't match"
|
||||
|
||||
def test_save_load_fused_lora_modules(self):
|
||||
pipeline_components, lora_components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
|
||||
sd_pipe = sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
# Emulate training.
|
||||
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1)
|
||||
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1)
|
||||
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
StableDiffusionXLPipeline.save_lora_weights(
|
||||
save_directory=tmpdirname,
|
||||
unet_lora_layers=lora_components["unet_lora_layers"],
|
||||
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
|
||||
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
|
||||
safe_serialization=True,
|
||||
)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
sd_pipe.fuse_lora()
|
||||
lora_images_fusion = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
|
||||
lora_image_slice_fusion = lora_images_fusion[0, -3:, -3:, -1]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
sd_pipe.save_pretrained(tmpdirname)
|
||||
sd_pipe_loaded = StableDiffusionXLPipeline.from_pretrained(tmpdirname)
|
||||
|
||||
loaded_lora_images = sd_pipe_loaded(**pipeline_inputs, generator=torch.manual_seed(0)).images
|
||||
loaded_lora_image_slice = loaded_lora_images[0, -3:, -3:, -1]
|
||||
|
||||
assert np.allclose(
|
||||
lora_image_slice_fusion, loaded_lora_image_slice, atol=1e-03
|
||||
), "The pipeline was serialized with LoRA parameters fused inside of the respected modules. The loaded pipeline should yield proper outputs, henceforth."
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user