From 963ffca43419a8dffa682d9e03c2299b76feeced Mon Sep 17 00:00:00 2001
From: Emmanuel Benazera
Date: Tue, 3 Dec 2024 04:10:20 +0100
Subject: [PATCH] fix: missing AutoencoderKL lora adapter (#9807)

* fix: missing AutoencoderKL lora adapter

* fix

---------

Co-authored-by: Sayak Paul
---
 .../models/autoencoders/autoencoder_kl.py    |  3 +-
 tests/models/autoencoders/test_models_vae.py | 38 +++++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 99a7da4a0b..9036c027a5 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -17,6 +17,7 @@ import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
@@ -34,7 +35,7 @@ from ..modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
 
-class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
 
diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py
index d29defbf60..d475160cc7 100644
--- a/tests/models/autoencoders/test_models_vae.py
+++ b/tests/models/autoencoders/test_models_vae.py
@@ -36,7 +36,9 @@ from diffusers.utils.testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
+    is_peft_available,
     load_hf_numpy,
+    require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
     require_torch_gpu,
@@ -50,6 +52,10 @@ from diffusers.utils.torch_utils import randn_tensor
 from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
 
 
+if is_peft_available():
+    from peft import LoraConfig
+
+
 enable_full_determinism()
 
 
@@ -263,6 +269,38 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
 
         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
 
+    @require_peft_backend
+    def test_lora_adapter(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        vae = self.model_class(**init_dict)
+
+        target_modules_vae = [
+            "conv1",
+            "conv2",
+            "conv_in",
+            "conv_shortcut",
+            "conv",
+            "conv_out",
+            "skip_conv_1",
+            "skip_conv_2",
+            "skip_conv_3",
+            "skip_conv_4",
+            "to_k",
+            "to_q",
+            "to_v",
+            "to_out.0",
+        ]
+        vae_lora_config = LoraConfig(
+            r=16,
+            init_lora_weights="gaussian",
+            target_modules=target_modules_vae,
+        )
+
+        vae.add_adapter(vae_lora_config, adapter_name="vae_lora")
+        active_lora = vae.active_adapters()
+        self.assertTrue(len(active_lora) == 1)
+        self.assertTrue(active_lora[0] == "vae_lora")
+
 
 class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = AsymmetricAutoencoderKL
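
Usage sketch (not part of the patch): with PeftAdapterMixin mixed into AutoencoderKL, a LoRA adapter can be attached directly to a standalone VAE, mirroring the test added above. This assumes diffusers with the PEFT backend installed; the checkpoint and the reduced target-module list below are illustrative choices, not part of this change.

# Minimal sketch of the capability enabled by this patch.
from diffusers import AutoencoderKL
from peft import LoraConfig

# Illustrative checkpoint; any AutoencoderKL weights would do.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# A LoRA config targeting a subset of the conv/attention modules used in the test.
lora_config = LoraConfig(
    r=16,
    init_lora_weights="gaussian",
    target_modules=["conv_in", "conv_out", "to_k", "to_q", "to_v", "to_out.0"],
)

# add_adapter/active_adapters come from PeftAdapterMixin, now inherited by AutoencoderKL.
vae.add_adapter(lora_config, adapter_name="vae_lora")
print(vae.active_adapters())  # ["vae_lora"]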