mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix: missing AutoencoderKL lora adapter (#9807)
* fix: missing AutoencoderKL lora adapter * fix --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
committed by
GitHub
parent
30f2e9bd20
commit
963ffca434
@@ -17,6 +17,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import deprecate
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
@@ -34,7 +35,7 @@ from ..modeling_utils import ModelMixin
|
||||
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
|
||||
@@ -36,7 +36,9 @@ from diffusers.utils.testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
is_peft_available,
|
||||
load_hf_numpy,
|
||||
require_peft_backend,
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_fp16,
|
||||
require_torch_gpu,
|
||||
@@ -50,6 +52,10 @@ from diffusers.utils.torch_utils import randn_tensor
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@@ -263,6 +269,38 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
@require_peft_backend
|
||||
def test_lora_adapter(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
vae = self.model_class(**init_dict)
|
||||
|
||||
target_modules_vae = [
|
||||
"conv1",
|
||||
"conv2",
|
||||
"conv_in",
|
||||
"conv_shortcut",
|
||||
"conv",
|
||||
"conv_out",
|
||||
"skip_conv_1",
|
||||
"skip_conv_2",
|
||||
"skip_conv_3",
|
||||
"skip_conv_4",
|
||||
"to_k",
|
||||
"to_q",
|
||||
"to_v",
|
||||
"to_out.0",
|
||||
]
|
||||
vae_lora_config = LoraConfig(
|
||||
r=16,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules_vae,
|
||||
)
|
||||
|
||||
vae.add_adapter(vae_lora_config, adapter_name="vae_lora")
|
||||
active_lora = vae.active_adapters()
|
||||
self.assertTrue(len(active_lora) == 1)
|
||||
self.assertTrue(active_lora[0] == "vae_lora")
|
||||
|
||||
|
||||
class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = AsymmetricAutoencoderKL
|
||||
|
||||
Reference in New Issue
Block a user