
fix: missing AutoencoderKL lora adapter (#9807)

* fix: missing AutoencoderKL lora adapter

* fix

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Author: Emmanuel Benazera
Date: 2024-12-03 04:10:20 +01:00
Committed by: GitHub
Parent: 30f2e9bd20
Commit: 963ffca434

2 changed files with 40 additions and 1 deletion

src/diffusers/models/autoencoders/autoencoder_kl.py

@@ -17,6 +17,7 @@ import torch
 import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import deprecate
 from ...utils.accelerate_utils import apply_forward_hook
@@ -34,7 +35,7 @@ from ..modeling_utils import ModelMixin
 from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
 
 
-class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
     r"""
     A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
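Adding PeftAdapterMixin to the class bases gives AutoencoderKL the standard diffusers PEFT adapter API (add_adapter, active_adapters, set_adapter, enable_adapters, disable_adapters). A minimal sketch of what this enables, assuming peft is installed; the checkpoint name, adapter name, and target modules below are illustrative choices, not part of this commit:

from diffusers import AutoencoderKL
from peft import LoraConfig

# Load a pretrained KL autoencoder (illustrative checkpoint).
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# Wrap a couple of conv layers with a small LoRA adapter.
lora_config = LoraConfig(
    r=8,
    init_lora_weights="gaussian",
    target_modules=["conv1", "conv2"],
)
vae.add_adapter(lora_config, adapter_name="my_vae_lora")

print(vae.active_adapters())  # ["my_vae_lora"]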

tests/models/autoencoders/test_models_vae.py

@@ -36,7 +36,9 @@ from diffusers.utils.testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
+    is_peft_available,
     load_hf_numpy,
+    require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
     require_torch_gpu,
@@ -50,6 +52,10 @@ from diffusers.utils.torch_utils import randn_tensor
 from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
 
 
+if is_peft_available():
+    from peft import LoraConfig
+
+
 enable_full_determinism()
@@ -263,6 +269,38 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
 
+    @require_peft_backend
+    def test_lora_adapter(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        vae = self.model_class(**init_dict)
+
+        target_modules_vae = [
+            "conv1",
+            "conv2",
+            "conv_in",
+            "conv_shortcut",
+            "conv",
+            "conv_out",
+            "skip_conv_1",
+            "skip_conv_2",
+            "skip_conv_3",
+            "skip_conv_4",
+            "to_k",
+            "to_q",
+            "to_v",
+            "to_out.0",
+        ]
+        vae_lora_config = LoraConfig(
+            r=16,
+            init_lora_weights="gaussian",
+            target_modules=target_modules_vae,
+        )
+
+        vae.add_adapter(vae_lora_config, adapter_name="vae_lora")
+        active_lora = vae.active_adapters()
+        self.assertTrue(len(active_lora) == 1)
+        self.assertTrue(active_lora[0] == "vae_lora")
+
 
 class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = AsymmetricAutoencoderKL
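Once attached, the adapter can be toggled like on any other PEFT-enabled diffusers model. Continuing the sketch above (same hypothetical adapter name):

# Run the VAE with the base weights only, then re-enable the LoRA layers.
vae.disable_adapters()
vae.enable_adapters()

# When several adapters are attached, select the active one by name.
vae.set_adapter("my_vae_lora")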