1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Files
diffusers/tests/lora/utils.py
2025-12-04 20:07:15 +08:00

1983 lines
100 KiB
Python

# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import re
from itertools import product
import numpy as np
import pytest
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.hooks.group_offloading import _GROUP_OFFLOADING, apply_group_offloading
from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available
from ..testing_utils import (
CaptureLogger,
check_if_dicts_are_equal,
floats_tensor,
is_torch_version,
require_peft_backend,
require_peft_version_greater,
require_torch_accelerator,
require_transformers_version_greater,
skip_mps,
torch_device,
)
if is_peft_available():
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import get_peft_model_state_dict
def state_dicts_almost_equal(sd1, sd2):
    """
    Return `True` when two state dicts hold the same keys and numerically close tensors.

    Tensors are paired by key. Two tensors are considered close when their largest absolute element-wise
    difference does not exceed `1e-3`.

    Args:
        sd1: Mapping from parameter name to tensor.
        sd2: Mapping from parameter name to tensor.

    Returns:
        `False` if the key sets differ, if any paired tensors have different shapes, or if any pair differs by
        more than the tolerance; `True` otherwise.
    """
    # Pairing by position (zipping the sorted values) silently truncates on unequal lengths and may compare
    # unrelated tensors, so the key sets have to match before any values are compared.
    if sd1.keys() != sd2.keys():
        return False

    for key, ten1 in sd1.items():
        ten2 = sd2[key]
        if ten1.shape != ten2.shape:
            return False
        # An empty tensor has nothing that could differ, and `.max()` would raise on it.
        if ten1.numel() > 0 and (ten1 - ten2).abs().max() > 1e-3:
            return False
    return True
def check_if_lora_correctly_set(model) -> bool:
    """
    Reports whether `model` contains at least one peft tuner layer (`BaseTunerLayer`) among its modules.
    """
    return any(isinstance(submodule, BaseTunerLayer) for submodule in model.modules())
def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str):
    """
    Compares the metadata entries of one module against the expected adapter metadata.

    Entries of `parsed_metadata` whose key starts with `"{module_key}."` are collected with that prefix
    stripped, then checked against `lora_metadatas["{module_key}_lora_adapter_metadata"]` through
    `check_if_dicts_are_equal`.
    """
    prefix = f"{module_key}."
    extracted = {}
    for key, value in parsed_metadata.items():
        if key.startswith(prefix):
            extracted[key.removeprefix(prefix)] = value
    expected = lora_metadatas[f"{module_key}_lora_adapter_metadata"]
    check_if_dicts_are_equal(extracted, expected)
def initialize_dummy_state_dict(state_dict):
    """
    Builds a state dict of random tensors mirroring the shapes and dtypes of a meta-device state dict.

    Raises:
        ValueError: If any value of `state_dict` is not on the `meta` device.
    """
    for tensor in state_dict.values():
        if tensor.device.type != "meta":
            raise ValueError("`state_dict` has non-meta values.")

    dummy_state_dict = {}
    for name, tensor in state_dict.items():
        dummy_state_dict[name] = torch.randn(tensor.shape, device=torch_device, dtype=tensor.dtype)
    return dummy_state_dict
POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


def determine_attention_kwargs_name(pipeline_class):
    """
    Finds which attention-kwargs parameter name the pipeline's `__call__` accepts.

    The candidates in `POSSIBLE_ATTENTION_KWARGS_NAMES` are tried in order and the first one present in the
    signature of `pipeline_class.__call__` is returned.

    Raises:
        AssertionError: If `__call__` accepts none of the candidate names.
    """
    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()

    # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
    # Default to `None` so that a missing match reaches the assertion below instead of failing earlier with
    # an `UnboundLocalError`.
    attention_kwargs_name = next(
        (candidate for candidate in POSSIBLE_ATTENTION_KWARGS_NAMES if candidate in call_signature_keys),
        None,
    )
    assert attention_kwargs_name is not None, (
        f"`{pipeline_class.__name__}.__call__` accepts none of {POSSIBLE_ATTENTION_KWARGS_NAMES}."
    )
    return attention_kwargs_name
# Mixin holding LoRA (peft) loader tests. The class attributes below are placeholders that
# `get_dummy_components` and the test methods read.
@require_peft_backend
class PeftLoraLoaderMixinTests:
# Pipeline class under test; instantiated as `self.pipeline_class(**components)`.
pipeline_class = None
# Scheduler class and the kwargs it is constructed with in `get_dummy_components`.
scheduler_cls = None
scheduler_kwargs = None
# `get_dummy_components` raises `ValueError` when both flags are True.
has_two_text_encoders = False
has_three_text_encoders = False
# Text encoders and tokenizers are loaded via `<cls>.from_pretrained(<id>, subfolder=<subfolder>)`.
# The `_2` / `_3` variants are only loaded when the matching `text_encoder_*_cls` is not None.
text_encoder_cls, text_encoder_id, text_encoder_subfolder = None, None, ""
text_encoder_2_cls, text_encoder_2_id, text_encoder_2_subfolder = None, None, ""
text_encoder_3_cls, text_encoder_3_id, text_encoder_3_subfolder = None, None, ""
tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
# Denoiser configuration: `get_dummy_components` raises `ValueError` when both `unet_kwargs` and
# `transformer_kwargs` are set. With `unet_kwargs` a `UNet2DConditionModel` is built, otherwise
# `transformer_cls(**transformer_kwargs)`.
unet_kwargs = None
transformer_cls = None
transformer_kwargs = None
# VAE class and the kwargs it is constructed with.
vae_cls = AutoencoderKL
vae_kwargs = None
# `target_modules` passed to the text-encoder and denoiser `LoraConfig` respectively.
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
@property
def output_shape(self):
    """
    Expected shape of the pipeline output; raises `NotImplementedError` here, so it must be overridden.
    """
    raise NotImplementedError()
@pytest.fixture(scope="class")
def base_pipe_output(self):
    """
    Class-scoped fixture returning the pipeline output computed by `_compute_baseline_output`.
    """
    baseline_output = self._compute_baseline_output()
    return baseline_output
@pytest.fixture(scope="function")
def tmpdirname(self, tmp_path_factory):
    """
    Function-scoped fixture returning a temporary directory created through `tmp_path_factory`.
    """
    scratch_dir = tmp_path_factory.mktemp("tmp")
    return scratch_dir
@pytest.fixture(scope="function")
def pipe(self):
    """
    Function-scoped fixture returning a pipeline built from the dummy components and moved to `torch_device`.
    """
    dummy_components, _, _ = self.get_dummy_components()
    pipeline = self.pipeline_class(**dummy_components).to(torch_device)
    pipeline.set_progress_bar_config(disable=None)
    return pipeline
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
    """
    Builds the pipeline components together with the text-encoder and denoiser LoRA configs.

    Args:
        scheduler_cls: Scheduler class to instantiate; falls back to `self.scheduler_cls` when `None`.
        use_dora: Forwarded as `use_dora` to both `LoraConfig` objects.
        lora_alpha: LoRA alpha for both configs; falls back to the rank (4) when `None`.

    Returns:
        A tuple `(pipeline_components, text_lora_config, denoiser_lora_config)`.

    Raises:
        ValueError: If both `unet_kwargs` and `transformer_kwargs` are set, or if both
            `has_two_text_encoders` and `has_three_text_encoders` are True.
    """
    if self.unet_kwargs and self.transformer_kwargs:
        raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
    if self.has_two_text_encoders and self.has_three_text_encoders:
        raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

    if scheduler_cls is None:
        scheduler_cls = self.scheduler_cls
    rank = 4
    if lora_alpha is None:
        lora_alpha = rank

    # The module construction order and the two seeding calls below are kept as they were, because the
    # modules are created after seeding the global torch RNG.
    torch.manual_seed(0)
    if self.unet_kwargs is not None:
        unet = UNet2DConditionModel(**self.unet_kwargs)
    else:
        transformer = self.transformer_cls(**self.transformer_kwargs)

    scheduler = scheduler_cls(**self.scheduler_kwargs)

    torch.manual_seed(0)
    vae = self.vae_cls(**self.vae_kwargs)

    text_encoder = self.text_encoder_cls.from_pretrained(
        self.text_encoder_id, subfolder=self.text_encoder_subfolder
    )
    tokenizer = self.tokenizer_cls.from_pretrained(self.tokenizer_id, subfolder=self.tokenizer_subfolder)

    # Optional extra text encoders, stored tokenizer-first to keep the original key order.
    extra_text_components = {}
    if self.text_encoder_2_cls is not None:
        second_encoder = self.text_encoder_2_cls.from_pretrained(
            self.text_encoder_2_id, subfolder=self.text_encoder_2_subfolder
        )
        second_tokenizer = self.tokenizer_2_cls.from_pretrained(
            self.tokenizer_2_id, subfolder=self.tokenizer_2_subfolder
        )
        extra_text_components["tokenizer_2"] = second_tokenizer
        extra_text_components["text_encoder_2"] = second_encoder
    if self.text_encoder_3_cls is not None:
        third_encoder = self.text_encoder_3_cls.from_pretrained(
            self.text_encoder_3_id, subfolder=self.text_encoder_3_subfolder
        )
        third_tokenizer = self.tokenizer_3_cls.from_pretrained(
            self.tokenizer_3_id, subfolder=self.tokenizer_3_subfolder
        )
        extra_text_components["tokenizer_3"] = third_tokenizer
        extra_text_components["text_encoder_3"] = third_encoder

    shared_lora_kwargs = {
        "r": rank,
        "lora_alpha": lora_alpha,
        "init_lora_weights": False,
        "use_dora": use_dora,
    }
    text_lora_config = LoraConfig(target_modules=self.text_encoder_target_modules, **shared_lora_kwargs)
    denoiser_lora_config = LoraConfig(target_modules=self.denoiser_target_modules, **shared_lora_kwargs)

    pipeline_components = {
        "scheduler": scheduler,
        "vae": vae,
        "text_encoder": text_encoder,
        "tokenizer": tokenizer,
    }

    # Denoiser
    if self.unet_kwargs is not None:
        pipeline_components["unet"] = unet
    elif self.transformer_kwargs is not None:
        pipeline_components["transformer"] = transformer

    # Remaining text encoders.
    pipeline_components.update(extra_text_components)

    # Components the pipeline constructor may accept but which are passed as `None`.
    init_params = inspect.signature(self.pipeline_class.__init__).parameters
    for optional_name in ("safety_checker", "feature_extractor", "image_encoder"):
        if optional_name in init_params:
            pipeline_components[optional_name] = None

    return pipeline_components, text_lora_config, denoiser_lora_config
def get_dummy_inputs(self, with_generator=True):
    """
    Creates the dummy noise tensor, input ids and pipeline call kwargs used by the tests.

    Args:
        with_generator: When True, the seeded generator is added to the call kwargs under `"generator"`.

    Returns:
        A tuple `(noise, input_ids, pipeline_inputs)`.
    """
    batch_size, sequence_length, num_channels = 1, 10, 4
    sizes = (32, 32)

    generator = torch.manual_seed(0)
    noise = floats_tensor((batch_size, num_channels, *sizes))
    input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

    pipeline_inputs = dict(
        prompt="A painting of a squirrel eating a burger",
        num_inference_steps=5,
        guidance_scale=6.0,
        output_type="np",
    )
    if with_generator:
        pipeline_inputs["generator"] = generator

    return noise, input_ids, pipeline_inputs
def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
    """
    Attaches LoRA adapters to the text encoder(s) and/or the denoiser of `pipe`.

    Args:
        pipe: Pipeline whose modules receive the adapters.
        text_lora_config: Config added to `text_encoder` (and `text_encoder_2` for pipelines flagged with two
            or three text encoders) when those modules are LoRA-loadable. Skipped when `None`.
        denoiser_lora_config: Config added to the denoiser (`pipe.unet` when `unet_kwargs` is set, otherwise
            `pipe.transformer`). Skipped when `None`.
        adapter_name: Name under which the adapters are registered.

    Returns:
        A tuple `(pipe, denoiser)`, where `denoiser` is `None` if no denoiser config was given.
    """
    lora_loadable_modules = self.pipeline_class._lora_loadable_modules

    if text_lora_config is not None:
        if "text_encoder" in lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
            assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

    if denoiser_lora_config is not None:
        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
        assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
    else:
        denoiser = None

    # `and` binds tighter than `or`: without the parentheses a three-text-encoder pipeline would call
    # `add_adapter(None, ...)` on `text_encoder_2` even though no text LoRA config was supplied.
    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if text_lora_config is not None and has_extra_text_encoder:
        if "text_encoder_2" in lora_loadable_modules:
            pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
            assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"

    return pipe, denoiser
def _compute_baseline_output(self):
    """
    Runs a pipeline built from the dummy components, with no LoRA attached, and returns its first output.
    """
    dummy_components, _, _ = self.get_dummy_components()
    pipeline = self.pipeline_class(**dummy_components).to(torch_device)
    pipeline.set_progress_bar_config(disable=None)

    # The inputs are requested without a `generator` so that a freshly seeded one can be passed explicitly.
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipeline_result = pipeline(**call_kwargs, generator=torch.manual_seed(0))
    return pipeline_result[0]
def _get_lora_state_dicts(self, modules_to_save):
state_dicts = {}
for module_name, module in modules_to_save.items():
if module is not None:
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
return state_dicts
def _get_lora_adapter_metadata(self, modules_to_save):
metadatas = {}
for module_name, module in modules_to_save.items():
if module is not None:
metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
return metadatas
def _get_modules_to_save(self, pipe, has_denoiser=False):
modules_to_save = {}
lora_loadable_modules = self.pipeline_class._lora_loadable_modules
if (
"text_encoder" in lora_loadable_modules
and hasattr(pipe, "text_encoder")
and (getattr(pipe.text_encoder, "peft_config", None) is not None)
):
modules_to_save["text_encoder"] = pipe.text_encoder
if (
"text_encoder_2" in lora_loadable_modules
and hasattr(pipe, "text_encoder_2")
and getattr(pipe.text_encoder_2, "peft_config", None) is not None
):
modules_to_save["text_encoder_2"] = pipe.text_encoder_2
if has_denoiser:
if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
modules_to_save["unet"] = pipe.unet
if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
modules_to_save["transformer"] = pipe.transformer
return modules_to_save
def test_simple_inference(self, base_pipe_output):
    """
    Checks that the baseline pipeline output has the shape declared by `output_shape`.
    """
    actual_shape = base_pipe_output.shape
    assert actual_shape == self.output_shape
def test_simple_inference_with_text_lora(self, base_pipe_output, pipe):
    """
    Attaches a LoRA adapter to the text encoder only and checks that the output
    differs from the baseline.
    """
    _, text_cfg, _ = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

    pipe, _ = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_lora_config=None)

    lora_image = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]
    matches_baseline = np.allclose(lora_image, base_pipe_output, atol=1e-3, rtol=1e-3)
    assert not matches_baseline, "Lora should change the output"
@require_peft_version_greater("0.13.1")
def test_low_cpu_mem_usage_with_injection(self, pipe):
    """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
    _, text_cfg, denoiser_cfg = self.get_dummy_components()
    lora_loadable_modules = self.pipeline_class._lora_loadable_modules

    def inject_then_materialize(module, lora_config, not_set_message):
        # Inject on the meta device, verify, then load real tensors and verify nothing is left on meta.
        inject_adapter_in_model(lora_config, module, low_cpu_mem_usage=True)
        assert check_if_lora_correctly_set(module), not_set_message
        assert "meta" in {p.device.type for p in module.parameters()}, (
            "The LoRA params should be on 'meta' device."
        )

        dummy_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(module))
        set_peft_model_state_dict(module, dummy_state_dict, low_cpu_mem_usage=True)
        assert "meta" not in {p.device.type for p in module.parameters()}, (
            "No param should be on 'meta' device."
        )

    if "text_encoder" in lora_loadable_modules:
        inject_then_materialize(pipe.text_encoder, text_cfg, "Lora not correctly set in text encoder.")

    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    inject_then_materialize(denoiser, denoiser_cfg, "Lora not correctly set in denoiser.")

    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder and "text_encoder_2" in lora_loadable_modules:
        inject_then_materialize(pipe.text_encoder_2, text_cfg, "Lora not correctly set in text encoder 2")

    _, _, call_kwargs = self.get_dummy_inputs()
    lora_image = pipe(**call_kwargs)[0]
    assert lora_image.shape == self.output_shape
@require_peft_version_greater("0.13.1")
@require_transformers_version_greater("4.45.2")
def test_low_cpu_mem_usage_with_loading(self, tmpdirname, pipe):
    """Tests if we can load LoRA state dict with low_cpu_mem_usage."""
    _, text_cfg, denoiser_cfg = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_cfg)

    def generate():
        return pipe(**call_kwargs, generator=torch.manual_seed(0))[0]

    def assert_lora_set(modules):
        for module_name, module in modules.items():
            assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"

    images_lora = generate()

    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
    assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

    # First reload: regular loading path.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
    assert_lora_set(modules_to_save)
    images_regular = generate()
    assert np.allclose(images_lora, images_regular, atol=1e-3, rtol=1e-3), (
        "Loading from saved checkpoints should give same results."
    )

    # Second reload: `low_cpu_mem_usage=True` must reproduce the regular loading result.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
    assert_lora_set(modules_to_save)
    images_low_cpu = generate()
    assert np.allclose(images_low_cpu, images_regular, atol=1e-3, rtol=1e-3), (
        "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results."
    )
def test_simple_inference_with_text_lora_and_scale(self, base_pipe_output, pipe):
    """
    Attaches a text-encoder LoRA and checks the effect of the `scale` attention kwarg:
    the adapter changes the output, scale 0.5 changes it again, and scale 0.0 matches the baseline.
    """
    tolerance = {"atol": 1e-3, "rtol": 1e-3}
    attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
    _, text_cfg, _ = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_lora_config=None)

    def generate(**extra_kwargs):
        return pipe(**call_kwargs, generator=torch.manual_seed(0), **extra_kwargs)[0]

    output_lora = generate()
    assert not np.allclose(output_lora, base_pipe_output, **tolerance), "Lora should change the output"

    output_lora_scale = generate(**{attention_kwargs_name: {"scale": 0.5}})
    assert not np.allclose(output_lora, output_lora_scale, **tolerance), (
        "Lora + scale should change the output"
    )

    output_lora_0_scale = generate(**{attention_kwargs_name: {"scale": 0.0}})
    assert np.allclose(base_pipe_output, output_lora_0_scale, **tolerance), (
        "Lora + 0 scale should lead to same result as no LoRA"
    )
def test_simple_inference_with_text_lora_fused(self, base_pipe_output, pipe):
    """
    Attaches a text-encoder LoRA, fuses it, and checks that the LoRA layers are still
    present and that the fused output differs from the baseline.
    """
    _, text_cfg, _ = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_lora_config=None)

    pipe.fuse_lora()
    assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
        assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"

    output_fused = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]
    matches_baseline = np.allclose(output_fused, base_pipe_output, atol=1e-3, rtol=1e-3)
    assert not matches_baseline, "Fused lora should change the output"
def test_simple_inference_with_text_lora_unloaded(self, base_pipe_output, pipe):
    """
    Attaches a text-encoder LoRA, unloads it, and checks that no LoRA layers remain
    and that the output matches the baseline again.
    """
    _, text_cfg, _ = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_lora_config=None)

    pipe.unload_lora_weights()
    still_set = check_if_lora_correctly_set(pipe.text_encoder)
    assert not still_set, "Lora not correctly unloaded in text encoder"

    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
        assert not check_if_lora_correctly_set(pipe.text_encoder_2), (
            "Lora not correctly unloaded in text encoder 2"
        )

    output_unloaded = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]
    assert np.allclose(output_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3), (
        "Unloading lora should match the base pipe output"
    )
def test_simple_inference_with_text_lora_save_load(self, tmpdirname, pipe):
    """
    Saves the text-encoder LoRA weights with `save_lora_weights`, reloads them with
    `load_lora_weights`, and checks that the output is unchanged.
    """
    _, text_cfg, _ = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_lora_config=None)

    images_before = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]

    modules_to_save = self._get_modules_to_save(pipe)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
    assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

    pipe.unload_lora_weights()
    pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
    for module_name, module in modules_to_save.items():
        assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"

    images_after = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]
    assert np.allclose(images_before, images_after, atol=1e-3, rtol=1e-3), (
        "Loading from saved checkpoints should give same results."
    )
def test_simple_inference_with_partial_text_lora(self, base_pipe_output, pipe):
    """
    Attaches a text-encoder LoRA with per-module ranks, then reloads a state dict from which
    the entries of `text_model.encoder.layers.4` were dropped, and checks that each step changes the output.
    """
    tolerance = {"atol": 1e-3, "rtol": 1e-3}
    lora_loadable_modules = self.pipeline_class._lora_loadable_modules

    text_lora_config = LoraConfig(
        r=4,
        rank_pattern={self.text_encoder_target_modules[idx]: idx + 1 for idx in range(3)},
        lora_alpha=4,
        target_modules=self.text_encoder_target_modules,
        init_lora_weights=False,
        use_dora=False,
    )
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

    def partial_state_dict(prefix, encoder):
        # Keep every LoRA tensor except those belonging to encoder layer 4.
        return {
            f"{prefix}.{module_name}": param
            for module_name, param in get_peft_model_state_dict(encoder).items()
            if "text_model.encoder.layers.4" not in module_name
        }

    state_dict = {}
    if "text_encoder" in lora_loadable_modules:
        state_dict = partial_state_dict("text_encoder", pipe.text_encoder)

    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder and "text_encoder_2" in lora_loadable_modules:
        state_dict.update(partial_state_dict("text_encoder_2", pipe.text_encoder_2))

    output_lora = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]
    assert not np.allclose(output_lora, base_pipe_output, **tolerance), "Lora should change the output"

    # Swap the full adapter for the partial one.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(state_dict)

    output_partial_lora = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]
    assert not np.allclose(output_partial_lora, output_lora, **tolerance), (
        "Removing adapters should change the output"
    )
def test_simple_inference_save_pretrained_with_text_lora(self, tmpdirname, pipe):
    """
    Saves a pipeline carrying a text-encoder LoRA with `save_pretrained`, reloads it with
    `from_pretrained`, and checks that the LoRA layers and the output survive the round trip.
    """
    _, text_cfg, _ = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_lora_config=None)

    images_lora = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]

    pipe.save_pretrained(tmpdirname)
    reloaded_pipe = self.pipeline_class.from_pretrained(tmpdirname)
    reloaded_pipe.to(torch_device)

    lora_loadable_modules = self.pipeline_class._lora_loadable_modules
    if "text_encoder" in lora_loadable_modules:
        assert check_if_lora_correctly_set(reloaded_pipe.text_encoder), (
            "Lora not correctly set in text encoder"
        )

    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder and "text_encoder_2" in lora_loadable_modules:
        assert check_if_lora_correctly_set(reloaded_pipe.text_encoder_2), (
            "Lora not correctly set in text encoder 2"
        )

    images_reloaded = reloaded_pipe(**call_kwargs, generator=torch.manual_seed(0))[0]
    assert np.allclose(images_lora, images_reloaded, atol=1e-3, rtol=1e-3), (
        "Loading from saved checkpoints should give same results."
    )
def test_simple_inference_with_text_denoiser_lora_save_load(self, tmpdirname, pipe):
    """
    Saves text-encoder and denoiser LoRA weights with `save_lora_weights`, reloads them with
    `load_lora_weights`, and checks that the output is unchanged.
    """
    _, text_cfg, denoiser_cfg = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_cfg)

    images_before = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]

    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
    assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

    pipe.unload_lora_weights()
    pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
    for module_name, module in modules_to_save.items():
        assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"

    images_after = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]
    assert np.allclose(images_before, images_after, atol=1e-3, rtol=1e-3), (
        "Loading from saved checkpoints should give same results."
    )
def test_simple_inference_with_text_denoiser_lora_and_scale(self, base_pipe_output, pipe):
    """
    Attaches LoRA to the text encoder and the denoiser and checks the effect of the `scale`
    attention kwarg, including that the text encoder's LoRA scaling is back at 1.0 afterwards.
    """
    tolerance = {"atol": 1e-3, "rtol": 1e-3}
    attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
    _, text_cfg, denoiser_cfg = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_cfg)

    def generate(**extra_kwargs):
        return pipe(**call_kwargs, generator=torch.manual_seed(0), **extra_kwargs)[0]

    output_lora = generate()
    assert not np.allclose(output_lora, base_pipe_output, **tolerance), "Lora should change the output"

    output_lora_scale = generate(**{attention_kwargs_name: {"scale": 0.5}})
    assert not np.allclose(output_lora, output_lora_scale, **tolerance), (
        "Lora + scale should change the output"
    )

    output_lora_0_scale = generate(**{attention_kwargs_name: {"scale": 0.0}})
    assert np.allclose(base_pipe_output, output_lora_0_scale, **tolerance), (
        "Lora + 0 scale should lead to same result as no LoRA"
    )

    if "text_encoder" in self.pipeline_class._lora_loadable_modules:
        first_q_proj = pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj
        assert first_q_proj.scaling["default"] == 1.0, (
            "The scaling parameter has not been correctly restored!"
        )
def test_simple_inference_with_text_lora_denoiser_fused(self, base_pipe_output, pipe):
    """
    Attaches LoRA to the text encoder and the denoiser, fuses all LoRA-loadable components,
    and checks that the LoRA layers are still present and the output differs from the baseline.
    """
    lora_loadable_modules = self.pipeline_class._lora_loadable_modules
    _, text_cfg, denoiser_cfg = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_cfg)

    pipe.fuse_lora(components=lora_loadable_modules)

    if "text_encoder" in lora_loadable_modules:
        assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser"

    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder and "text_encoder_2" in lora_loadable_modules:
        assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"

    output_fused = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]
    matches_baseline = np.allclose(output_fused, base_pipe_output, atol=1e-3, rtol=1e-3)
    assert not matches_baseline, "Fused lora should change the output"
def test_simple_inference_with_text_denoiser_lora_unloaded(self, base_pipe_output, pipe):
    """
    Tests a simple inference with lora attached to text encoder and denoiser, then unloads the lora weights
    and makes sure no LoRA layers remain and the output matches the baseline again.
    """
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

    pipe.unload_lora_weights()
    # After unloading, no module may still contain peft tuner layers.
    assert not check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
    assert not check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser"

    if self.has_two_text_encoders or self.has_three_text_encoders:
        if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
            assert not check_if_lora_correctly_set(pipe.text_encoder_2), (
                "Lora not correctly unloaded in text encoder 2"
            )

    output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
    # The previous failure message ("Fused lora should change the output") contradicted this check.
    assert np.allclose(output_unloaded, base_pipe_output, atol=1e-3, rtol=1e-3), (
        "Unloading lora should match the base pipe output"
    )
def test_simple_inference_with_text_denoiser_lora_unfused(
    self, pipe, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
):
    """
    Attaches LoRA to the text encoder and the denoiser, fuses and then unfuses it, and checks
    the fused-LoRA counter, that the LoRA layers are kept, and that both outputs agree within
    `expected_atol` / `expected_rtol`.
    """
    lora_loadable_modules = self.pipeline_class._lora_loadable_modules
    _, text_cfg, denoiser_cfg = self.get_dummy_components()
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)
    pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_cfg, denoiser_cfg)

    pipe.fuse_lora(components=lora_loadable_modules)
    assert pipe.num_fused_loras == 1, (
        f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}"
    )
    output_fused_lora = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]

    pipe.unfuse_lora(components=lora_loadable_modules)
    assert pipe.num_fused_loras == 0, (
        f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}"
    )
    output_unfused_lora = pipe(**call_kwargs, generator=torch.manual_seed(0))[0]

    # Unfusing must leave the adapter layers in place.
    if "text_encoder" in lora_loadable_modules:
        assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers"

    assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers"

    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder and "text_encoder_2" in lora_loadable_modules:
        assert check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"

    outputs_agree = np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol)
    assert outputs_agree, "Fused lora should not change the output"
def test_simple_inference_with_text_denoiser_multi_adapter(self, base_pipe_output):
    """
    Registers two adapters on the text encoder(s) and the denoiser, activates them one at a time
    and together, and checks that all outputs differ from the baseline and from each other;
    disabling LoRA must restore the baseline.
    """
    tolerance = {"atol": 1e-3, "rtol": 1e-3}
    adapter_names = ("adapter-1", "adapter-2")
    lora_loadable_modules = self.pipeline_class._lora_loadable_modules

    components, text_cfg, denoiser_cfg = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _, _, call_kwargs = self.get_dummy_inputs(with_generator=False)

    def generate():
        return pipe(**call_kwargs, generator=torch.manual_seed(0))[0]

    if "text_encoder" in lora_loadable_modules:
        for name in adapter_names:
            pipe.text_encoder.add_adapter(text_cfg, name)
        assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    for name in adapter_names:
        denoiser.add_adapter(denoiser_cfg, name)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder and "text_encoder_2" in lora_loadable_modules:
        for name in adapter_names:
            pipe.text_encoder_2.add_adapter(text_cfg, name)
        assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"

    # Activate each adapter selection in turn; every one has to move the output away from the baseline.
    outputs = {}
    selections = (("1", "adapter-1"), ("2", "adapter-2"), ("mixed", ["adapter-1", "adapter-2"]))
    for label, active_adapters in selections:
        pipe.set_adapters(active_adapters)
        outputs[label] = generate()
        assert not np.allclose(base_pipe_output, outputs[label], **tolerance), (
            "Adapter outputs should be different."
        )

    assert not np.allclose(outputs["1"], outputs["2"], **tolerance), (
        "Adapter 1 and 2 should give different results"
    )
    assert not np.allclose(outputs["1"], outputs["mixed"], **tolerance), (
        "Adapter 1 and mixed adapters should give different results"
    )
    assert not np.allclose(outputs["2"], outputs["mixed"], **tolerance), (
        "Adapter 2 and mixed adapters should give different results"
    )

    pipe.disable_lora()
    output_disabled = generate()
    assert np.allclose(base_pipe_output, output_disabled, **tolerance), (
        "output with no lora and output with lora disabled should give same results"
    )
def test_wrong_adapter_name_raises_error(self):
    """
    Activating an adapter name that was never added must raise a `ValueError`; afterwards the
    adapter that does exist can still be activated and used for inference.
    """
    known_adapter = "adapter-1"
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    pipe, _ = self.add_adapters_to_pipeline(
        pipe, text_lora_config, denoiser_lora_config, adapter_name=known_adapter
    )

    # "test" is not the name the adapters were added under.
    with pytest.raises(ValueError) as excinfo:
        pipe.set_adapters("test")
    error_message = str(excinfo.value)
    assert "not in the list of present adapters" in error_message

    # The rejected call must not get in the way of using the registered adapter.
    pipe.set_adapters(known_adapter)
    seeded_generator = torch.manual_seed(0)
    _ = pipe(**inputs, generator=seeded_generator)[0]
def test_multiple_wrong_adapter_name_raises_error(self, pipe):
    """
    Passing `adapter_weights` keyed by components the pipeline does not have should log a message
    listing those components, and the adapter should still be usable afterwards.
    """
    lora_name = "adapter-1"
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config, adapter_name=lora_name)

    # None of these keys is checked against a real component; they are expected to be reported back.
    bogus_scales = {"foo": 0.0, "bar": 0.0, "tik": 0.0}

    lora_logger = logging.get_logger("diffusers.loaders.lora_base")
    lora_logger.setLevel(30)
    with CaptureLogger(lora_logger) as captured:
        pipe.set_adapters(lora_name, adapter_weights=bogus_scales)

    # The expected message embeds the sorted list of unknown component names.
    unknown_components = sorted(bogus_scales)
    expected = (
        f"The following components in `adapter_weights` are not part of the pipeline: {unknown_components}. "
    )
    assert expected in str(captured.out)

    # Sanity check: plain activation and inference still work.
    pipe.set_adapters(lora_name)
    seeded_generator = torch.manual_seed(0)
    _ = pipe(**inputs, generator=seeded_generator)[0]
def test_simple_inference_with_text_denoiser_block_scale(self, base_pipe_output, pipe):
    """
    Tests a simple inference with lora attached to text encoder and unet, attaches
    one adapter and set different weights for different blocks (i.e. block lora)
    """
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
    assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    # Register the denoiser LoRA under the same name that `set_adapters` targets below. When it is
    # added without a name, the "unet" block scales are presumably never applied to it, so the test
    # would only exercise the text encoder. NOTE(review): confirm against `add_adapter`'s default name.
    denoiser.add_adapter(denoiser_lora_config, "adapter-1")
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    if self.has_two_text_encoders or self.has_three_text_encoders:
        if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
            assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"

    # First configuration: scale the text encoder and the "down" part of the unet.
    weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
    pipe.set_adapters("adapter-1", weights_1)
    output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

    # Second configuration: scale only the "up" part of the unet.
    weights_2 = {"unet": {"up": 5}}
    pipe.set_adapters("adapter-1", weights_2)
    output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

    assert not np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3), (
        "LoRA weights 1 and 2 should give different results"
    )
    assert not np.allclose(base_pipe_output, output_weights_1, atol=1e-3, rtol=1e-3), (
        "No adapter and LoRA weights 1 should give different results"
    )
    assert not np.allclose(base_pipe_output, output_weights_2, atol=1e-3, rtol=1e-3), (
        "No adapter and LoRA weights 2 should give different results"
    )

    # With LoRA disabled the pipeline must reproduce the baseline output.
    pipe.disable_lora()
    output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
    assert np.allclose(base_pipe_output, output_disabled, atol=1e-3, rtol=1e-3), (
        "output with no lora and output with lora disabled should give same results"
    )
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self, base_pipe_output, pipe):
    """
    Attach two LoRA adapters to the text encoder(s) and the denoiser, then drive them with per-block
    scale dictionaries (block lora): each adapter alone, both together, LoRA disabled, and finally a
    call whose number of scale entries does not match the number of adapter names.
    """
    first, second = "adapter-1", "adapter-2"
    loadable_modules = self.pipeline_class._lora_loadable_modules
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    def run_pipe():
        # Re-seed on every call so outputs are comparable across adapter setups.
        return pipe(**inputs, generator=torch.manual_seed(0))[0]

    def outputs_match(lhs, rhs):
        return np.allclose(lhs, rhs, atol=1e-3, rtol=1e-3)

    if "text_encoder" in loadable_modules:
        for name in (first, second):
            pipe.text_encoder.add_adapter(text_lora_config, name)
        assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    for name in (first, second):
        denoiser.add_adapter(denoiser_lora_config, name)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder and "text_encoder_2" in loadable_modules:
        for name in (first, second):
            pipe.text_encoder_2.add_adapter(text_lora_config, name)
        assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"

    scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
    scales_2 = {"unet": {"down": 5, "mid": 5}}

    pipe.set_adapters(first, scales_1)
    output_adapter_1 = run_pipe()

    pipe.set_adapters(second, scales_2)
    output_adapter_2 = run_pipe()

    pipe.set_adapters([first, second], [scales_1, scales_2])
    output_adapter_mixed = run_pipe()

    # All three configurations must be pairwise distinguishable.
    assert not outputs_match(output_adapter_1, output_adapter_2), "Adapter 1 and 2 should give different results"
    assert not outputs_match(output_adapter_1, output_adapter_mixed), (
        "Adapter 1 and mixed adapters should give different results"
    )
    assert not outputs_match(output_adapter_2, output_adapter_mixed), (
        "Adapter 2 and mixed adapters should give different results"
    )

    pipe.disable_lora()
    output_disabled = run_pipe()
    assert outputs_match(base_pipe_output, output_disabled), (
        "output with no lora and output with lora disabled should give same results"
    )

    # Two adapter names but only one scale entry: expected to be rejected.
    with pytest.raises(ValueError):
        pipe.set_adapters([first, second], [scales_1])
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self, pipe):
    """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""

    def updown_options(blocks_with_tf, layers_per_block, value):
        """
        Generate every possible combination for how a lora weight dict for the up/down part can be.
        E.g. 2, {"block_1": 2}, {"block_1": [2,2,2]}, {"block_1": 2, "block_2": [2,2,2]}, ...
        """
        num_val = value
        list_val = [value] * layers_per_block

        # Each block can be left out, given one number, or given one number per layer.
        node_opts = [None, num_val, list_val]
        node_opts_foreach_block = [node_opts] * len(blocks_with_tf)

        updown_opts = [num_val]
        for nodes in product(*node_opts_foreach_block):
            opt = {f"block_{b}": n for b, n in zip(blocks_with_tf, nodes) if n is not None}
            # An empty dict means every block was left out; that combination is not emitted.
            if opt:
                updown_opts.append(opt)
        return updown_opts

    def all_possible_dict_opts(unet, value):
        """
        Generate every possible combination for how a lora weight dict can be.
        E.g. 2, {"unet: {"down": 2}}, {"unet: {"down": [2,2,2]}}, {"unet: {"mid": 2, "up": [2,2,2]}}, ...
        """
        down_blocks_with_tf = [i for i, d in enumerate(unet.down_blocks) if hasattr(d, "attentions")]
        up_blocks_with_tf = [i for i, u in enumerate(unet.up_blocks) if hasattr(u, "attentions")]

        layers_per_block = unet.config.layers_per_block

        text_encoder_opts = [None, value]
        text_encoder_2_opts = [None, value]
        mid_opts = [None, value]
        down_opts = [None] + updown_options(down_blocks_with_tf, layers_per_block, value)
        # Up blocks are enumerated with one extra entry per block.
        up_opts = [None] + updown_options(up_blocks_with_tf, layers_per_block + 1, value)

        opts = []
        for t1, t2, d, m, u in product(text_encoder_opts, text_encoder_2_opts, down_opts, mid_opts, up_opts):
            # Combinations without any unet scale are skipped; this also covers the case where
            # every entry is `None`, so no separate check is needed for it.
            if d is None and m is None and u is None:
                continue

            opt = {}
            if t1 is not None:
                opt["text_encoder"] = t1
            if t2 is not None:
                opt["text_encoder_2"] = t2
            opt["unet"] = {
                part: scale for part, scale in (("down", d), ("mid", m), ("up", u)) if scale is not None
            }
            opts.append(opt)
        return opts

    # Called without positional arguments, like every other test here, so the LoRA configs are
    # built the same way as in the rest of the suite.
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()

    pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    denoiser.add_adapter(denoiser_lora_config, "adapter-1")

    if self.has_two_text_encoders or self.has_three_text_encoders:
        lora_loadable_components = self.pipeline_class._lora_loadable_modules
        if "text_encoder_2" in lora_loadable_components:
            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")

    for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
        # test if lora block scales can be set with this scale_dict
        if not self.has_two_text_encoders and "text_encoder_2" in scale_dict:
            del scale_dict["text_encoder_2"]

        pipe.set_adapters("adapter-1", scale_dict)  # test will fail if this line throws an error
def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self, base_pipe_output, pipe):
    """
    Tests a simple inference with lora attached to text encoder and unet, attaches
    multiple adapters and set/delete them
    """
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    if "text_encoder" in self.pipeline_class._lora_loadable_modules:
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
        assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    denoiser.add_adapter(denoiser_lora_config, "adapter-1")
    denoiser.add_adapter(denoiser_lora_config, "adapter-2")
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    if self.has_two_text_encoders or self.has_three_text_encoders:
        lora_loadable_components = self.pipeline_class._lora_loadable_modules
        if "text_encoder_2" in lora_loadable_components:
            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
            assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"

    pipe.set_adapters("adapter-1")
    output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

    pipe.set_adapters("adapter-2")
    output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

    pipe.set_adapters(["adapter-1", "adapter-2"])
    output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

    assert not np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), (
        "Adapter 1 and 2 should give different results"
    )
    assert not np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3), (
        "Adapter 1 and mixed adapters should give different results"
    )
    assert not np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3), (
        "Adapter 2 and mixed adapters should give different results"
    )

    # After deleting adapter 1 the output is expected to equal the adapter-2-only output.
    pipe.delete_adapters("adapter-1")
    output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
    assert np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3), (
        "Output after deleting adapter 1 should match the output of adapter 2 alone"
    )

    # With the last adapter gone the pipeline must reproduce the baseline output.
    pipe.delete_adapters("adapter-2")
    output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
    assert np.allclose(base_pipe_output, output_deleted_adapters, atol=1e-3, rtol=1e-3), (
        "output with no lora and output with all adapters deleted should give same results"
    )

    # Re-add both adapters, then delete them in a single call using a list of names.
    if "text_encoder" in self.pipeline_class._lora_loadable_modules:
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    denoiser.add_adapter(denoiser_lora_config, "adapter-1")
    denoiser.add_adapter(denoiser_lora_config, "adapter-2")
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    pipe.set_adapters(["adapter-1", "adapter-2"])
    pipe.delete_adapters(["adapter-1", "adapter-2"])

    output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
    assert np.allclose(base_pipe_output, output_deleted_adapters, atol=1e-3, rtol=1e-3), (
        "output with no lora and output with all adapters deleted should give same results"
    )
def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self, base_pipe_output, pipe):
    """
    Attach two LoRA adapters to the text encoder(s) and the denoiser and compare the outputs of:
    each adapter alone, both adapters together, both adapters with explicit weights, and LoRA disabled.
    """
    first, second = "adapter-1", "adapter-2"
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    def run_pipe():
        # Same seed for every run so that differences come from the adapter setup alone.
        return pipe(**inputs, generator=torch.manual_seed(0))[0]

    def outputs_match(lhs, rhs):
        return np.allclose(lhs, rhs, atol=1e-3, rtol=1e-3)

    if "text_encoder" in self.pipeline_class._lora_loadable_modules:
        for name in (first, second):
            pipe.text_encoder.add_adapter(text_lora_config, name)
        assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    for name in (first, second):
        denoiser.add_adapter(denoiser_lora_config, name)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    if self.has_two_text_encoders or self.has_three_text_encoders:
        lora_loadable_components = self.pipeline_class._lora_loadable_modules
        if "text_encoder_2" in lora_loadable_components:
            for name in (first, second):
                pipe.text_encoder_2.add_adapter(text_lora_config, name)
            assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"

    pipe.set_adapters(first)
    output_adapter_1 = run_pipe()

    pipe.set_adapters(second)
    output_adapter_2 = run_pipe()

    pipe.set_adapters([first, second])
    output_adapter_mixed = run_pipe()

    assert not outputs_match(output_adapter_1, output_adapter_2), "Adapter 1 and 2 should give different results"
    assert not outputs_match(output_adapter_1, output_adapter_mixed), (
        "Adapter 1 and mixed adapters should give different results"
    )
    assert not outputs_match(output_adapter_2, output_adapter_mixed), (
        "Adapter 2 and mixed adapters should give different results"
    )

    # Same two adapters, now with explicit per-adapter weights.
    pipe.set_adapters([first, second], [0.5, 0.6])
    output_adapter_mixed_weighted = run_pipe()
    assert not outputs_match(output_adapter_mixed_weighted, output_adapter_mixed), (
        "Weighted adapter and mixed adapter should give different results"
    )

    pipe.disable_lora()
    output_disabled = run_pipe()
    assert outputs_match(base_pipe_output, output_disabled), (
        "output with no lora and output with lora disabled should give same results"
    )
@skip_mps
@pytest.mark.xfail(
    condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
    reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
    strict=False,
)
def test_lora_fuse_nan(self, pipe):
    """
    Corrupt one LoRA weight with `inf` and check that `fuse_lora(safe_fusing=True)` raises a
    `ValueError`, while `safe_fusing=False` fuses anyway and yields an all-NaN output.
    """
    lora_modules = self.pipeline_class._lora_loadable_modules
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    if "text_encoder" in lora_modules:
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
        assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    denoiser.add_adapter(denoiser_lora_config, "adapter-1")
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    # Corrupt one `to_q` LoRA A matrix in place; gradients are not needed for that.
    with torch.no_grad():
        if self.unet_kwargs:
            mid_block_layer = pipe.unet.mid_block.attentions[0].transformer_blocks[0]
            mid_block_layer.attn1.to_q.lora_A["adapter-1"].weight += float("inf")
        else:
            transformer = pipe.transformer
            module_names = [name for name, _ in transformer.named_modules()]
            possible_tower_names = [
                "transformer_blocks",
                "blocks",
                "joint_transformer_blocks",
                "single_transformer_blocks",
            ]
            present_towers = [tower for tower in possible_tower_names if hasattr(transformer, tower)]
            if not present_towers:
                reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
                raise ValueError(reason)

            # The attention attribute is called `attn1` in some transformers and `attn` in others;
            # the answer does not depend on the tower, so it is computed once.
            attention_attr = "attn1" if any("attn1" in name for name in module_names) else "attn"
            for tower in present_towers:
                first_layer = getattr(transformer, tower)[0]
                getattr(first_layer, attention_attr).to_q.lora_A["adapter-1"].weight += float("inf")

    # Safe fusing is expected to reject the corrupted weights.
    with pytest.raises(ValueError):
        pipe.fuse_lora(components=lora_modules, safe_fusing=True)

    # Unsafe fusing goes through; the resulting output is expected to be NaN everywhere.
    pipe.fuse_lora(components=lora_modules, safe_fusing=False)
    out = pipe(**inputs)[0]
    assert np.isnan(out).all()
def test_get_adapters(self, pipe):
    """
    Check `get_active_adapters()`: after each `add_adapter` round only the newly added adapter is
    reported, and after `set_adapters` with both names both are reported.
    """
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer

    # Round 1: only "adapter-1" exists.
    pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
    denoiser.add_adapter(denoiser_lora_config, "adapter-1")
    active = pipe.get_active_adapters()
    assert active == ["adapter-1"]

    # Round 2: "adapter-2" is added on top and is the only one reported as active.
    pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
    denoiser.add_adapter(denoiser_lora_config, "adapter-2")
    active = pipe.get_active_adapters()
    assert active == ["adapter-2"]

    # Round 3: explicitly activate both; the reported order is not relied upon.
    both = ["adapter-1", "adapter-2"]
    pipe.set_adapters(both)
    assert sorted(pipe.get_active_adapters()) == ["adapter-1", "adapter-2"]
def test_get_list_adapters(self, pipe):
    """
    Check `get_list_adapters()` at four stages: one adapter added, a second one added, both
    activated via `set_adapters`, and a third adapter added to the denoiser only.
    """
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()

    with_text_encoder = "text_encoder" in self.pipeline_class._lora_loadable_modules
    denoiser_key = "unet" if self.unet_kwargs is not None else "transformer"

    def expected_listing(text_encoder_adapters, denoiser_adapters):
        # The text encoder entry only exists when the pipeline can load LoRA into it.
        listing = {}
        if with_text_encoder:
            listing["text_encoder"] = text_encoder_adapters
        listing[denoiser_key] = denoiser_adapters
        return listing

    # 1.
    if with_text_encoder:
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
    getattr(pipe, denoiser_key).add_adapter(denoiser_lora_config, "adapter-1")
    assert pipe.get_list_adapters() == expected_listing(["adapter-1"], ["adapter-1"])

    # 2.
    if with_text_encoder:
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
    getattr(pipe, denoiser_key).add_adapter(denoiser_lora_config, "adapter-2")
    assert pipe.get_list_adapters() == expected_listing(["adapter-1", "adapter-2"], ["adapter-1", "adapter-2"])

    # 3.
    pipe.set_adapters(["adapter-1", "adapter-2"])
    assert pipe.get_list_adapters() == expected_listing(["adapter-1", "adapter-2"], ["adapter-1", "adapter-2"])

    # 4.
    getattr(pipe, denoiser_key).add_adapter(denoiser_lora_config, "adapter-3")
    assert pipe.get_list_adapters() == expected_listing(
        ["adapter-1", "adapter-2"], ["adapter-1", "adapter-2", "adapter-3"]
    )
def test_simple_inference_with_text_lora_denoiser_fused_multi(
    self,
    pipe,
    expected_atol: float = 1e-3,
    expected_rtol: float = 1e-3,
):
    """
    Multi-adapter fuse/unfuse round trip: with two adapters attached, fusing one adapter or both
    must reproduce the corresponding un-fused outputs, `num_fused_loras` must track the number of
    fused adapters, and unfusing must keep the LoRA layers in place.
    """
    lora_modules = self.pipeline_class._lora_loadable_modules
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    def generate():
        return pipe(**inputs, generator=torch.manual_seed(0))[0]

    def same_output(lhs, rhs):
        return np.allclose(lhs, rhs, atol=expected_atol, rtol=expected_rtol)

    def fused_state():
        # Only rendered when one of the bookkeeping assertions fails.
        return f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}"

    if "text_encoder" in lora_modules:
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
        assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
        pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    denoiser.add_adapter(denoiser_lora_config, "adapter-1")
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
    denoiser.add_adapter(denoiser_lora_config, "adapter-2")

    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder:
        lora_loadable_components = self.pipeline_class._lora_loadable_modules
        if "text_encoder_2" in lora_loadable_components:
            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
            assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
            pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")

    # Reference outputs before any fusing: both adapters, then adapter-1 alone.
    pipe.set_adapters(["adapter-1", "adapter-2"])
    outputs_all_lora = generate()

    pipe.set_adapters(["adapter-1"])
    outputs_lora_1 = generate()

    # Fuse adapter-1 only.
    pipe.fuse_lora(components=lora_modules, adapter_names=["adapter-1"])
    assert pipe.num_fused_loras == 1, fused_state()

    outputs_lora_1_fused = generate()
    assert same_output(outputs_lora_1, outputs_lora_1_fused), "Fused lora should not change the output"

    pipe.unfuse_lora(components=lora_modules)
    assert pipe.num_fused_loras == 0, fused_state()

    # The LoRA layers must still be present on every component after unfusing.
    if "text_encoder" in lora_modules:
        assert check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers"

    assert check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers"

    if has_extra_text_encoder and "text_encoder_2" in lora_modules:
        assert check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"

    # Fuse both adapters and compare against the two-adapter reference.
    pipe.fuse_lora(components=lora_modules, adapter_names=["adapter-2", "adapter-1"])
    assert pipe.num_fused_loras == 2, fused_state()

    output_all_lora_fused = generate()
    assert same_output(output_all_lora_fused, outputs_all_lora), "Fused lora should not change the output"

    pipe.unfuse_lora(components=lora_modules)
    assert pipe.num_fused_loras == 0, fused_state()
def test_lora_scale_kwargs_match_fusion(
    self, base_pipe_output, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
):
    """
    For several LoRA scales, passing the scale through the pipeline's attention kwargs must give
    the same output as fusing the adapter with that `lora_scale`, and must differ from the baseline.
    """
    attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
    lora_modules = self.pipeline_class._lora_loadable_modules

    def same_output(lhs, rhs):
        return np.allclose(lhs, rhs, atol=expected_atol, rtol=expected_rtol)

    def check_scale(lora_scale):
        # A fresh pipeline per scale, so that fusing in one round cannot leak into the next.
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components).to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        if "text_encoder" in lora_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                assert check_if_lora_correctly_set(pipe.text_encoder_2), (
                    "Lora not correctly set in text encoder 2"
                )

        # Reference: the scale is supplied at call time through the attention kwargs.
        pipe.set_adapters(["adapter-1"])
        attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
        outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

        # Candidate: the same scale is baked in while fusing.
        pipe.fuse_lora(
            components=lora_modules,
            adapter_names=["adapter-1"],
            lora_scale=lora_scale,
        )
        assert pipe.num_fused_loras == 1, (
            f"pipe.num_fused_loras={pipe.num_fused_loras!r}, pipe.fused_loras={pipe.fused_loras!r}"
        )

        outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

        assert same_output(outputs_lora_1, outputs_lora_1_fused), "Fused lora should not change the output"
        assert not same_output(base_pipe_output, outputs_lora_1), "LoRA should change the output"

    for scale in (1.0, 0.8):
        check_scale(scale)
def test_simple_inference_with_dora(self, pipe):
    """
    Run the pipeline before and after attaching adapters whose configs were built with
    `use_dora=True`; the first output must have the expected shape and the second must differ from it.
    """
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    def generate():
        return pipe(**inputs, generator=torch.manual_seed(0))[0]

    # Baseline without any adapter attached.
    output_no_dora_lora = generate()
    assert output_no_dora_lora.shape == self.output_shape

    pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

    output_dora_lora = generate()
    outputs_unchanged = np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3)
    assert not outputs_unchanged, "DoRA lora should change the output"
def test_missing_keys_warning(self, tmpdirname, pipe):
    """
    Save the denoiser LoRA weights, drop one `lora_A` entry from the reloaded state dict and check
    that loading the incomplete state dict logs the name of the dropped key.
    """
    _, _, denoiser_lora_config = self.get_dummy_components()

    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    denoiser.add_adapter(denoiser_lora_config)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
    pipe.unload_lora_weights()

    weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.bin")
    assert os.path.isfile(weights_path)
    state_dict = torch.load(weights_path, weights_only=True)

    # Remove one LoRA A entry so that loading has to report it.
    missing_key = [k for k in state_dict if "lora_A" in k][0]
    del state_dict[missing_key]

    logger = logging.get_logger("diffusers.utils.peft_utils")
    logger.setLevel(30)
    with CaptureLogger(logger) as cap_logger:
        pipe.load_lora_weights(state_dict)

    # Take the component prefix from the removed key itself. Picking an arbitrary element of the
    # set of all prefixes depends on set ordering and need not match `missing_key`.
    component = missing_key.split(".")[0]
    # "default_0." is stripped from the log before comparing -- presumably the adapter name that
    # loading inserts into the reported key; NOTE(review): confirm.
    assert missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", "")
def test_unexpected_keys_warning(self, tmpdirname, pipe):
    """
    Save the denoiser LoRA weights, add a bogus entry to the reloaded state dict and check that
    loading it logs the bogus key's `.diffusers_cat` suffix.
    """
    _, _, denoiser_lora_config = self.get_dummy_components()

    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    denoiser.add_adapter(denoiser_lora_config)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    lora_state_dicts = self._get_lora_state_dicts(self._get_modules_to_save(pipe, has_denoiser=True))
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts)
    pipe.unload_lora_weights()

    weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.bin")
    assert os.path.isfile(weights_path)
    state_dict = torch.load(weights_path, weights_only=True)

    # Derive the bogus key from an existing LoRA A key by appending a suffix.
    lora_a_keys = [key for key in state_dict if "lora_A" in key]
    unexpected_key = lora_a_keys[0] + ".diffusers_cat"
    state_dict[unexpected_key] = torch.tensor(1.0, device=torch_device)

    peft_logger = logging.get_logger("diffusers.utils.peft_utils")
    peft_logger.setLevel(30)
    with CaptureLogger(peft_logger) as captured:
        pipe.load_lora_weights(state_dict)

    assert ".diffusers_cat" in captured.out
@pytest.mark.skip("This is failing for now - need to investigate")
def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self, pipe):
    """
    Smoke test (currently skipped): attach LoRA adapters, wrap the unet and the text encoder(s)
    with `torch.compile` and run one inference to make sure it works.
    """
    compile_options = {"mode": "reduce-overhead", "fullgraph": True}

    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

    pipe.unet = torch.compile(pipe.unet, **compile_options)
    pipe.text_encoder = torch.compile(pipe.text_encoder, **compile_options)

    if self.has_two_text_encoders or self.has_three_text_encoders:
        pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, **compile_options)

    # Just makes sure it works.
    seeded_generator = torch.manual_seed(0)
    _ = pipe(**inputs, generator=seeded_generator)[0]
def test_modify_padding_mode(self, pipe):
    """Set `padding_mode` on every `Conv2d` of the VAE and the unet, then run one inference."""

    def set_pad_mode(network, mode="circular"):
        # Only `torch.nn.Conv2d` instances are touched.
        conv_layers = (layer for _, layer in network.named_modules() if isinstance(layer, torch.nn.Conv2d))
        for conv in conv_layers:
            conv.padding_mode = mode

    _pad_mode = "circular"
    for component in (pipe.vae, pipe.unet):
        set_pad_mode(component, _pad_mode)

    _, _, inputs = self.get_dummy_inputs()
    _ = pipe(**inputs)[0]
def test_logs_info_when_no_lora_keys_found(self, base_pipe_output, pipe):
    """
    Loading a state dict without usable LoRA keys must log a "No LoRA keys associated to ..."
    message and leave the pipeline output unchanged; the same message is expected when loading
    that state dict directly into each LoRA-loadable text encoder.
    """
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    # Keys that match no component of the pipeline.
    no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}

    logger = logging.get_logger("diffusers.loaders.peft")
    logger.setLevel(logging.WARNING)

    with CaptureLogger(logger) as cap_logger:
        pipe.load_lora_weights(no_op_state_dict)
    out_after_lora_attempt = pipe(**inputs, generator=torch.manual_seed(0))[0]

    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    assert cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}")
    assert np.allclose(base_pipe_output, out_after_lora_attempt, atol=1e-05, rtol=1e-05)

    # test only for text encoder
    for lora_module in self.pipeline_class._lora_loadable_modules:
        if "text_encoder" not in lora_module:
            continue

        text_encoder = getattr(pipe, lora_module)

        logger = logging.get_logger("diffusers.loaders.lora_base")
        logger.setLevel(logging.WARNING)

        # The prefix is the module name itself. Spelling out only "text_encoder" and
        # "text_encoder_2" left `prefix` unbound (or stale from an earlier iteration) for any
        # other text encoder module name.
        with CaptureLogger(logger) as cap_logger:
            self.pipeline_class.load_lora_into_text_encoder(
                no_op_state_dict, network_alphas=None, text_encoder=text_encoder, prefix=lora_module
            )

        assert cap_logger.out.startswith(f"No LoRA keys associated to {text_encoder.__class__.__name__}")
def test_set_adapters_match_attention_kwargs(self, base_pipe_output, tmpdirname, pipe):
    """Test to check if outputs after `set_adapters()` and attention kwargs match."""
    tolerance = {"atol": 1e-3, "rtol": 1e-3}

    attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)
    pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

    # 1) Scale passed per call through the pipeline's attention kwargs.
    lora_scale = 0.5
    attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
    output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
    scale_matches_base = np.allclose(base_pipe_output, output_lora_scale, **tolerance)
    assert not scale_matches_base, "Lora + scale should change the output"

    # 2) Same scale set on the adapter itself, with no kwargs at call time.
    pipe.set_adapters("default", lora_scale)
    output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
    adapter_matches_base = np.allclose(base_pipe_output, output_lora_scale_wo_kwargs, **tolerance)
    assert not adapter_matches_base, "Lora + scale should change the output"
    assert np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, **tolerance), (
        "Lora + scale should match the output of `set_adapters()`."
    )

    # 3) Round-trip the LoRA weights through disk into a freshly built pipeline.
    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
    assert os.path.isfile(weights_path)

    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    pipe.load_lora_weights(weights_path)

    # NOTE(review): `modules_to_save` was collected before `pipe` was rebuilt above, so this loop inspects the
    # modules gathered from the earlier pipeline object - confirm that this is intended.
    for module_name, module in modules_to_save.items():
        assert check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}"

    output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
    reloaded_matches_base = np.allclose(base_pipe_output, output_lora_from_pretrained, **tolerance)
    assert not reloaded_matches_base, "Lora + scale should change the output"
    assert np.allclose(output_lora_scale, output_lora_from_pretrained, **tolerance), (
        "Loading from saved checkpoints should give same results as attention_kwargs."
    )
    assert np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, **tolerance), (
        "Loading from saved checkpoints should give same results as set_adapters()."
    )
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self, base_pipe_output, pipe):
    """Check that toggling `lora_bias` on the denoiser LoRA config changes the pipeline output."""
    tolerance = {"atol": 1e-3, "rtol": 1e-3}
    adapter_name = "adapter-1"

    _, _, denoiser_lora_config = self.get_dummy_components()
    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer

    # Snapshot the bias of every targeted module that has one.
    # NOTE(review): `bias_values` is not read again in this test - confirm whether a comparison was intended.
    bias_values = {
        module_name: layer.bias.data.clone()
        for module_name, layer in denoiser.named_modules()
        if any(target in module_name for target in self.denoiser_target_modules) and layer.bias is not None
    }

    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    # First pass: adapter without a LoRA bias term.
    denoiser_lora_config.lora_bias = False
    denoiser.add_adapter(denoiser_lora_config, adapter_name)
    lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
    pipe.delete_adapters(adapter_name)

    # Second pass: same config, now with the LoRA bias enabled.
    denoiser_lora_config.lora_bias = True
    denoiser.add_adapter(denoiser_lora_config, adapter_name)
    lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

    # All three outputs (base, bias off, bias on) have to differ pairwise.
    assert not np.allclose(base_pipe_output, lora_bias_false_output, **tolerance)
    assert not np.allclose(base_pipe_output, lora_bias_true_output, **tolerance)
    assert not np.allclose(lora_bias_false_output, lora_bias_true_output, **tolerance)
def test_correct_lora_configs_with_different_ranks(self, base_pipe_output, pipe):
    """Check that `rank_pattern` and `alpha_pattern` overrides reach the PEFT config and alter the output."""
    tolerance = {"atol": 1e-3, "rtol": 1e-3}
    adapter_name = "adapter-1"

    _, _, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)
    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer

    # Reference run: the adapter uses the rank and alpha of the dummy config everywhere.
    denoiser.add_adapter(denoiser_lora_config, adapter_name)
    lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
    denoiser.delete_adapters(adapter_name)

    # Override target: the first module whose name contains `to_k` and `attn` but not `lora`.
    for candidate, _ in denoiser.named_modules():
        if "to_k" in candidate and "attn" in candidate and "lora" not in candidate:
            module_name_to_rank_update = candidate.replace(".base_layer.", ".")
            break

    # Double the rank for that single module only.
    updated_rank = denoiser_lora_config.r * 2
    denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
    denoiser.add_adapter(denoiser_lora_config, adapter_name)
    updated_rank_pattern = denoiser.peft_config[adapter_name].rank_pattern
    assert updated_rank_pattern == {module_name_to_rank_update: updated_rank}

    lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
    assert not np.allclose(base_pipe_output, lora_output_same_rank, **tolerance)
    assert not np.allclose(lora_output_diff_rank, lora_output_same_rank, **tolerance)
    denoiser.delete_adapters(adapter_name)

    # Same idea for alpha; the rank override from above is still present on the config.
    updated_alpha = denoiser_lora_config.lora_alpha * 2
    denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
    denoiser.add_adapter(denoiser_lora_config, adapter_name)
    assert denoiser.peft_config[adapter_name].alpha_pattern == {module_name_to_rank_update: updated_alpha}

    lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
    assert not np.allclose(base_pipe_output, lora_output_diff_alpha, **tolerance)
    assert not np.allclose(lora_output_diff_alpha, lora_output_same_rank, **tolerance)
def test_layerwise_casting_inference_denoiser(self):
    """Smoke-test denoiser layerwise casting with LoRA adapters attached.

    Three pipelines are built and run once each: no layerwise casting, float8 storage with float32 compute,
    and float8 storage with bfloat16 compute. When casting is enabled, parameter dtypes are verified first.
    """
    from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
    from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN

    def check_linear_dtype(module, storage_dtype, compute_dtype):
        skip_patterns = DEFAULT_SKIP_MODULES_PATTERN
        if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
            skip_patterns += tuple(module._skip_layerwise_casting_patterns)

        for layer_name, layer in module.named_modules():
            if not isinstance(layer, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                continue

            # Layers with `lora` in their name or matching a skip pattern are expected in the compute dtype.
            kept_in_compute_dtype = "lora" in layer_name or any(
                re.search(pattern, layer_name) for pattern in skip_patterns
            )
            expected_dtype = compute_dtype if kept_in_compute_dtype else storage_dtype

            weight = getattr(layer, "weight", None)
            if weight is not None:
                assert weight.dtype == expected_dtype
            bias = getattr(layer, "bias", None)
            if bias is not None:
                assert bias.dtype == expected_dtype

    def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipeline = self.pipeline_class(**components)
        pipeline = pipeline.to(torch_device, dtype=compute_dtype)
        pipeline.set_progress_bar_config(disable=None)

        pipeline, denoiser = self.add_adapters_to_pipeline(pipeline, text_lora_config, denoiser_lora_config)

        # Without a storage dtype the pipeline is returned as a plain baseline.
        if storage_dtype is None:
            return pipeline

        denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
        check_linear_dtype(denoiser, storage_dtype, compute_dtype)
        return pipeline

    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    baseline_pipe = initialize_pipeline(storage_dtype=None)
    baseline_pipe(**inputs, generator=torch.manual_seed(0))[0]

    fp8_fp32_pipe = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
    fp8_fp32_pipe(**inputs, generator=torch.manual_seed(0))[0]

    fp8_bf16_pipe = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
    fp8_bf16_pipe(**inputs, generator=torch.manual_seed(0))[0]
@require_peft_version_greater("0.14.0")
def test_layerwise_casting_peft_input_autocast_denoiser(self, tmpdirname):
    """
    A test that checks if layerwise casting works correctly with PEFT layers and forward pass does not fail. This
    is different from `test_layerwise_casting_inference_denoiser` as that disables the application of layerwise
    cast hooks on the PEFT layers (relevant logic in `models.modeling_utils.ModelMixin.enable_layerwise_casting`).
    In this test, we enable the layerwise casting on the PEFT layers as well. If run with PEFT version <= 0.14.0,
    this test will fail with the following error:
    ```
    RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Float8_e4m3fn != float
    ```
    See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
    """
    from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
    from diffusers.hooks.layerwise_casting import (
        _PEFT_AUTOCAST_DISABLE_HOOK,
        DEFAULT_SKIP_MODULES_PATTERN,
        apply_layerwise_casting,
    )

    storage_dtype = torch.float8_e4m3fn
    compute_dtype = torch.float32
    weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")

    def check_module(denoiser):
        # `patterns_to_check` is read from the enclosing scope; it is assigned below, before the first call.
        for name, module in denoiser.named_modules():
            if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
                continue
            # Modules matching a skip pattern are expected in the compute dtype, all others in the storage dtype.
            dtype_to_check = storage_dtype
            if any(re.search(pattern, name) for pattern in patterns_to_check):
                dtype_to_check = compute_dtype
            if getattr(module, "weight", None) is not None:
                assert module.weight.dtype == dtype_to_check
            if getattr(module, "bias", None) is not None:
                assert module.bias.dtype == dtype_to_check
            if isinstance(module, BaseTunerLayer):
                # The `is not None` comparison must be applied to the result of `getattr`, not to its default.
                assert getattr(module, "_diffusers_hook", None) is not None
                assert module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None

    # 1) Adapter added in-process, then layerwise casting applied to the whole denoiser.
    components, _, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device, dtype=compute_dtype)
    pipe.set_progress_bar_config(disable=None)

    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    denoiser.add_adapter(denoiser_lora_config)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
    if getattr(denoiser, "_skip_layerwise_casting_patterns", None) is not None:
        patterns_to_check += tuple(denoiser._skip_layerwise_casting_patterns)

    apply_layerwise_casting(
        denoiser, storage_dtype=storage_dtype, compute_dtype=compute_dtype, skip_modules_pattern=patterns_to_check
    )
    check_module(denoiser)

    _, _, inputs = self.get_dummy_inputs(with_generator=False)
    pipe(**inputs, generator=torch.manual_seed(0))[0]

    # 2) Same scenario, but with the LoRA weights saved to disk and loaded into a fresh pipeline.
    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
    assert os.path.isfile(weights_path)

    components, _, _ = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device, dtype=compute_dtype)
    pipe.set_progress_bar_config(disable=None)
    pipe.load_lora_weights(weights_path)

    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    apply_layerwise_casting(
        denoiser,
        storage_dtype=storage_dtype,
        compute_dtype=compute_dtype,
        skip_modules_pattern=patterns_to_check,
    )
    check_module(denoiser)

    _, _, inputs = self.get_dummy_inputs(with_generator=False)
    pipe(**inputs, generator=torch.manual_seed(0))[0]
@pytest.mark.parametrize("lora_alpha", [4, 8, 16])
def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha, tmpdirname, pipe):
    """Save LoRA weights together with adapter metadata and check the metadata parsed back per module."""
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
    pipe, _ = self.add_adapters_to_pipeline(
        pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
    )

    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas)
    pipe.unload_lora_weights()

    # The returned tuple has two or three entries here; the metadata is always the last one.
    out = pipe.lora_state_dict(tmpdirname, return_lora_metadata=True)
    if len(out) in (2, 3):
        *_, parsed_metadata = out

    def assert_metadata_for(module_key):
        # At least one parsed key must be namespaced under `module_key` before the detailed comparison.
        assert any(key.startswith(f"{module_key}.") for key in parsed_metadata)
        check_module_lora_metadata(
            parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=module_key
        )

    if self.transformer_kwargs is not None:
        denoiser_key = f"{self.pipeline_class.transformer_name}"
    else:
        denoiser_key = f"{self.pipeline_class.unet_name}"
    assert_metadata_for(denoiser_key)

    if "text_encoder" in self.pipeline_class._lora_loadable_modules:
        assert_metadata_for(self.pipeline_class.text_encoder_name)

    if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
        assert_metadata_for("text_encoder_2")
@pytest.mark.parametrize("lora_alpha", [4, 8, 16])
def test_lora_adapter_metadata_save_load_inference(self, lora_alpha, tmpdirname, pipe):
    """Check that LoRA weights saved with adapter metadata reproduce the same output after reloading."""
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    adapter_configs = {"text_lora_config": text_lora_config, "denoiser_lora_config": denoiser_lora_config}
    pipe, _ = self.add_adapters_to_pipeline(pipe, **adapter_configs)
    output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

    # Persist weights and metadata, drop the adapters, then restore them from disk.
    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts, **lora_metadatas)
    pipe.unload_lora_weights()
    pipe.load_lora_weights(tmpdirname)

    output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
    outputs_match = np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3)
    assert outputs_match, "Lora outputs should match."
def test_lora_unload_add_adapter(self, pipe):
    """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)
    adapter_configs = {"text_lora_config": text_lora_config, "denoiser_lora_config": denoiser_lora_config}

    # First round: attach the adapters and make sure a forward pass goes through.
    pipe, _ = self.add_adapters_to_pipeline(pipe, **adapter_configs)
    _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    # Second round: unload everything, attach the same configs again and run once more.
    pipe.unload_lora_weights()
    pipe, _ = self.add_adapters_to_pipeline(pipe, **adapter_configs)
    _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
def test_inference_load_delete_load_adapters(self, base_pipe_output, tmpdirname, pipe):
    """Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."""
    tolerance = {"atol": 1e-3, "rtol": 1e-3}

    _, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    if "text_encoder" in self.pipeline_class._lora_loadable_modules:
        pipe.text_encoder.add_adapter(text_lora_config)
        assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"

    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    denoiser.add_adapter(denoiser_lora_config)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    # The second text encoder only gets an adapter when the pipeline has one and declares it LoRA-loadable.
    has_extra_text_encoder = self.has_two_text_encoders or self.has_three_text_encoders
    if has_extra_text_encoder and "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
        pipe.text_encoder_2.add_adapter(text_lora_config)
        assert check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"

    output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
    assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

    # Deleting the first active adapter has to bring the output back to the base pipeline output.
    first_active_adapter = pipe.get_active_adapters()[0]
    pipe.delete_adapters(first_active_adapter)
    output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
    assert not np.allclose(output_adapter_1, output_no_adapter, **tolerance)
    assert np.allclose(base_pipe_output, output_no_adapter, **tolerance)

    # Loading the saved weights again has to reproduce the adapter output.
    pipe.load_lora_weights(tmpdirname)
    output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
    assert np.allclose(output_adapter_1, output_lora_loaded, **tolerance)
def _test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
    """Check that group offloading on the denoiser survives a LoRA load -> unload -> load cycle.

    Args:
        offload_type: Forwarded as `offload_type` to `denoiser.enable_group_offload()`.
        use_stream: Forwarded as `use_stream` to `denoiser.enable_group_offload()`.
        tmpdirname: Directory the LoRA weights are written to and loaded from.
        pipe: Pipeline used to create and save the LoRA weights; it is replaced by a fresh pipeline afterwards.
    """
    from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook

    onload_device = torch_device
    offload_device = torch.device("cpu")
    weights_path = os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")

    # Create a LoRA on the incoming pipeline's denoiser and persist it.
    _, _, denoiser_lora_config = self.get_dummy_components()
    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    denoiser.add_adapter(denoiser_lora_config)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)
    assert os.path.isfile(weights_path)

    # Fresh pipeline: load the LoRA first, then enable group offloading on the denoiser.
    components, _, _ = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.set_progress_bar_config(disable=None)
    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet

    pipe.load_lora_weights(weights_path)
    # The result must be asserted; a bare call discards the boolean and can never fail the test.
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    denoiser.enable_group_offload(
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type=offload_type,
        num_blocks_per_group=1,
        use_stream=use_stream,
    )
    # Move every `torch.nn.Module` component of the pipeline to `torch_device`.
    for component in pipe.components.values():
        if isinstance(component, torch.nn.Module):
            component.to(torch_device)

    group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
    assert group_offload_hook_1 is not None
    output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

    # The top-level hook has to still be present after the LoRA is unloaded; the forward pass only has to run.
    pipe.unload_lora_weights()
    group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
    assert group_offload_hook_2 is not None
    _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

    # Loading the LoRA again has to keep the hook and reproduce the first output.
    pipe.load_lora_weights(weights_path)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
    group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
    assert group_offload_hook_3 is not None
    output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]

    assert np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3)
@pytest.mark.parametrize(
    "offload_type, use_stream",
    [("block_level", True), ("leaf_level", False), ("leaf_level", True)],
)
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
    """Run the shared group-offloading check unless another class in the MRO defines this test itself."""
    test_name = "test_group_offloading_inference_denoiser"
    # Return early when any class other than `PeftLoraLoaderMixinTests` declares a method with this name.
    defined_outside_mixin = any(
        test_name in klass.__dict__ and klass is not PeftLoraLoaderMixinTests
        for klass in inspect.getmro(self.__class__)
    )
    if defined_outside_mixin:
        return
    self._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
@pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")
def test_lora_loading_model_cpu_offload(self, tmpdirname, pipe):
    """Check that LoRA weights loaded after `enable_model_cpu_offload()` reproduce the original LoRA output."""
    _, _, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    # Reference output: adapter added directly on the incoming pipeline's denoiser.
    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    denoiser.add_adapter(denoiser_lora_config)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
    output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
    lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
    self.pipeline_class.save_lora_weights(save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts)

    # Fresh pipeline: CPU offload is enabled before the LoRA weights are loaded.
    components, _, _ = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe.enable_model_cpu_offload(device=torch_device)
    pipe.load_lora_weights(tmpdirname)

    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
    assert np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3)
@require_torch_accelerator
def test_lora_group_offloading_delete_adapters(self, tmpdirname):
    """Check that deleting a LoRA adapter changes the output when the denoiser uses leaf-level group offloading."""
    components, _, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)

    denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
    denoiser.add_adapter(denoiser_lora_config)
    assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."

    try:
        # Write the adapter to disk so it can be loaded into a second, offloaded pipeline.
        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
        self.pipeline_class.save_lora_weights(
            save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
        )

        components, _, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
        pipe.to(torch_device)

        # Group offloading is applied to the denoiser before the LoRA weights are loaded.
        apply_group_offloading(
            denoiser,
            onload_device=torch_device,
            offload_device="cpu",
            offload_type="leaf_level",
        )

        pipe.load_lora_weights(tmpdirname, adapter_name="default")
        out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        # Outputs with and without the adapter have to differ.
        pipe.delete_adapters("default")
        out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        outputs_identical = np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3)
        assert not outputs_identical
    finally:
        # Remove the group-offloading hooks from whichever denoiser is current, pass or fail.
        if hasattr(denoiser, "_diffusers_hook"):
            denoiser._diffusers_hook.remove_hook(_GROUP_OFFLOADING, recurse=True)