mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Tests] QoL improvements to the LoRA test suite (#10304)
* misc lora test improvements. * updates * fixes to tests
This commit is contained in:
@@ -36,7 +36,6 @@ from diffusers.utils.testing_utils import (
|
||||
numpy_cosine_similarity_distance,
|
||||
require_big_gpu_with_torch_cuda,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
@@ -331,7 +330,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
@@ -340,85 +340,32 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_B_bias(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
# Testing opposite direction where the LoRA params are zero-padded.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# keep track of the bias values of the base layers to perform checks later.
|
||||
bias_values = {}
|
||||
for name, module in pipe.transformer.named_modules():
|
||||
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
if module.bias is not None:
|
||||
bias_values[name] = module.bias.data.clone()
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
denoiser_lora_config.lora_bias = False
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.delete_adapters("adapter-1")
|
||||
|
||||
denoiser_lora_config.lora_bias = True
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
|
||||
# for now this is flux control lora specific but can be generalized later and added to ./utils.py
|
||||
def test_correct_lora_configs_with_different_ranks(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
|
||||
# change the rank_pattern
|
||||
updated_rank = denoiser_lora_config.r * 2
|
||||
denoiser_lora_config.rank_pattern = {"single_transformer_blocks.0.attn.to_k": updated_rank}
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
assert pipe.transformer.peft_config["adapter-1"].rank_pattern == {
|
||||
"single_transformer_blocks.0.attn.to_k": updated_rank
|
||||
dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
|
||||
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
|
||||
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
# similarly change the alpha_pattern
|
||||
updated_alpha = denoiser_lora_config.lora_alpha * 2
|
||||
denoiser_lora_config.alpha_pattern = {"single_transformer_blocks.0.attn.to_k": updated_alpha}
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
assert pipe.transformer.peft_config["adapter-1"].alpha_pattern == {
|
||||
"single_transformer_blocks.0.attn.to_k": updated_alpha
|
||||
}
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
|
||||
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_lora_expanding_shape_with_normal_lora(self):
|
||||
# This test checks if it works when a lora with expanded shapes (like control loras) but
|
||||
# another lora with correct shapes is loaded. The opposite direction isn't supported and is
|
||||
# tested with it.
|
||||
def test_normal_lora_with_expanded_lora_raises_error(self):
|
||||
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
|
||||
# load shape expanded LoRA (such as Control LoRA).
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -26,18 +24,12 @@ from diffusers import (
|
||||
LTXPipeline,
|
||||
LTXVideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
floats_tensor,
|
||||
is_torch_version,
|
||||
require_peft_backend,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@@ -107,41 +99,6 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@skip_mps
|
||||
@pytest.mark.xfail(
|
||||
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=True,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
for scheduler_cls in self.scheduler_classes:
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
pipe.transformer.transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
|
||||
|
||||
# with `safe_fusing=True` we should see an Error
|
||||
with self.assertRaises(ValueError):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
|
||||
|
||||
# without we should not see an error, but every image will be black
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
|
||||
out = pipe(
|
||||
"test", num_inference_steps=2, max_sequence_length=inputs["max_sequence_length"], output_type="np"
|
||||
)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
|
||||
|
||||
@@ -1988,3 +1988,113 @@ class PeftLoraLoaderMixinTests:
|
||||
np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results as set_adapters().",
|
||||
)
|
||||
|
||||
@require_peft_version_greater("0.13.2")
|
||||
def test_lora_B_bias(self):
|
||||
# Currently, this test is only relevant for Flux Control LoRA as we are not
|
||||
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# keep track of the bias values of the base layers to perform checks later.
|
||||
bias_values = {}
|
||||
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
for name, module in denoiser.named_modules():
|
||||
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
|
||||
if module.bias is not None:
|
||||
bias_values[name] = module.bias.data.clone()
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
denoiser_lora_config.lora_bias = False
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_false_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
pipe.delete_adapters("adapter-1")
|
||||
|
||||
denoiser_lora_config.lora_bias = True
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
lora_bias_true_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
|
||||
|
||||
def test_correct_lora_configs_with_different_ranks(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
|
||||
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.delete_adapters("adapter-1")
|
||||
else:
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
|
||||
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
|
||||
for name, _ in denoiser.named_modules():
|
||||
if "to_k" in name and "attn" in name and "lora" not in name:
|
||||
module_name_to_rank_update = name.replace(".base_layer.", ".")
|
||||
break
|
||||
|
||||
# change the rank_pattern
|
||||
updated_rank = denoiser_lora_config.r * 2
|
||||
denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
updated_rank_pattern = pipe.unet.peft_config["adapter-1"].rank_pattern
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
|
||||
|
||||
self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
|
||||
|
||||
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.delete_adapters("adapter-1")
|
||||
else:
|
||||
pipe.transformer.delete_adapters("adapter-1")
|
||||
|
||||
# similarly change the alpha_pattern
|
||||
updated_alpha = denoiser_lora_config.lora_alpha * 2
|
||||
denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
|
||||
if self.unet_kwargs is not None:
|
||||
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
pipe.unet.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
|
||||
)
|
||||
else:
|
||||
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(
|
||||
pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
|
||||
)
|
||||
|
||||
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
|
||||
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
|
||||
|
||||
Reference in New Issue
Block a user