
[LoRA] improve LoRA fusion tests (#11274)

* improve lora fusion tests

* more improvements.

* remove comment

* update

* relax tolerance.

* num_fused_loras as a property

Co-authored-by: BenjaminBossan <benjamin.bossan@gmail.com>

* updates

* update

* fix

* fix

Co-authored-by: BenjaminBossan <benjamin.bossan@gmail.com>

* Update src/diffusers/loaders/lora_base.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

---------

Co-authored-by: BenjaminBossan <benjamin.bossan@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Authored by Sayak Paul on 2025-05-27 09:02:12 -07:00; committed by GitHub.
Parent: 5939ace91b
Commit: a4da216125
4 changed files with 150 additions and 31 deletions

View File

@@ -465,7 +465,7 @@ class LoraBaseMixin:
"""Utility class for handling LoRAs."""
_lora_loadable_modules = []
num_fused_loras = 0
_merged_adapters = set()
def load_lora_weights(self, **kwargs):
raise NotImplementedError("`load_lora_weights()` is not implemented.")
@@ -592,6 +592,9 @@ class LoraBaseMixin:
        if len(components) == 0:
            raise ValueError("`components` cannot be an empty list.")

        # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
        # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
        merged_adapter_names = set()

        for fuse_component in components:
            if fuse_component not in self._lora_loadable_modules:
                raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -601,13 +604,19 @@ class LoraBaseMixin:
            # check if diffusers model
            if issubclass(model.__class__, ModelMixin):
                model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
                for module in model.modules():
                    if isinstance(module, BaseTunerLayer):
                        merged_adapter_names.update(set(module.merged_adapters))

            # handle transformers models.
            if issubclass(model.__class__, PreTrainedModel):
                fuse_text_encoder_lora(
                    model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
                )
                for module in model.modules():
                    if isinstance(module, BaseTunerLayer):
                        merged_adapter_names.update(set(module.merged_adapters))

        self.num_fused_loras += 1
        self._merged_adapters = self._merged_adapters | merged_adapter_names

    def unfuse_lora(self, components: List[str] = [], **kwargs):
        r"""
@@ -661,9 +670,18 @@ class LoraBaseMixin:
            if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
                for module in model.modules():
                    if isinstance(module, BaseTunerLayer):
                        for adapter in set(module.merged_adapters):
                            if adapter and adapter in self._merged_adapters:
                                self._merged_adapters = self._merged_adapters - {adapter}
                        module.unmerge()

        self.num_fused_loras -= 1

    @property
    def num_fused_loras(self):
        return len(self._merged_adapters)

    @property
    def fused_loras(self):
        return self._merged_adapters
    def set_adapters(
        self,
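The substantive change in this file: `num_fused_loras` is no longer a manually incremented class attribute but a property derived from `_merged_adapters`, the set of adapter names collected from PEFT `BaseTunerLayer.merged_adapters` during fusion. A minimal standalone sketch (illustrative only, not the actual diffusers classes) of why set-backed bookkeeping cannot drift out of sync the way an increment/decrement counter could:

```python
# Standalone sketch of the new bookkeeping pattern; FusionTracker is a
# hypothetical stand-in for LoraBaseMixin, not part of diffusers.
class FusionTracker:
    def __init__(self):
        self._merged_adapters = set()

    def fuse(self, adapter_names):
        # Union, mirroring `self._merged_adapters | merged_adapter_names`.
        self._merged_adapters |= set(adapter_names)

    def unfuse(self, adapter_names):
        self._merged_adapters -= set(adapter_names)

    @property
    def num_fused_loras(self):
        # Always consistent with what is actually merged.
        return len(self._merged_adapters)


tracker = FusionTracker()
tracker.fuse(["adapter-1"])
tracker.fuse(["adapter-1", "adapter-2"])  # re-fusing a name cannot double-count
assert tracker.num_fused_loras == 2
tracker.unfuse(["adapter-1", "adapter-2"])
assert tracker.num_fused_loras == 0
```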

View File

@@ -124,6 +124,9 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

    def test_lora_scale_kwargs_match_fusion(self):
        super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)

    @unittest.skip("Not supported in CogVideoX.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

View File

@@ -117,6 +117,40 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
    def test_multiple_wrong_adapter_name_raises_error(self):
        super().test_multiple_wrong_adapter_name_raises_error()

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        if torch.cuda.is_available():
            expected_atol = 9e-2
            expected_rtol = 9e-2
        else:
            expected_atol = 1e-3
            expected_rtol = 1e-3
        super().test_simple_inference_with_text_denoiser_lora_unfused(
            expected_atol=expected_atol, expected_rtol=expected_rtol
        )

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        if torch.cuda.is_available():
            expected_atol = 9e-2
            expected_rtol = 9e-2
        else:
            expected_atol = 1e-3
            expected_rtol = 1e-3
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(
            expected_atol=expected_atol, expected_rtol=expected_rtol
        )

    def test_lora_scale_kwargs_match_fusion(self):
        if torch.cuda.is_available():
            expected_atol = 9e-2
            expected_rtol = 9e-2
        else:
            expected_atol = 1e-3
            expected_rtol = 1e-3
        super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)
    @slow
    @nightly
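The three SDXL overrides above repeat the same device-dependent tolerance selection. A hypothetical consolidation (not part of this diff) could factor it into one helper:

```python
import torch


def fusion_tolerances():
    # Hypothetical helper, not in the diff: fused-LoRA outputs drift more on
    # CUDA than on CPU, so the tests use looser tolerances there.
    if torch.cuda.is_available():
        return 9e-2, 9e-2
    return 1e-3, 1e-3


# e.g. inside an override:
#   expected_atol, expected_rtol = fusion_tolerances()
#   super().test_lora_scale_kwargs_match_fusion(
#       expected_atol=expected_atol, expected_rtol=expected_rtol
#   )
```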

View File

@@ -80,6 +80,18 @@ def initialize_dummy_state_dict(state_dict):
POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


def determine_attention_kwargs_name(pipeline_class):
    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()

    # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
    for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
        if possible_attention_kwargs in call_signature_keys:
            attention_kwargs_name = possible_attention_kwargs
            break
    assert attention_kwargs_name is not None
    return attention_kwargs_name
@require_peft_backend
class PeftLoraLoaderMixinTests:
    pipeline_class = None
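Usage sketch for the extracted helper, assuming `determine_attention_kwargs_name` is in scope (it sits next to the test mixin) and a pipeline class whose `__call__` accepts one of the candidate kwargs:

```python
from diffusers import StableDiffusionXLPipeline

# SDXL's __call__ exposes `cross_attention_kwargs`, the first candidate in
# POSSIBLE_ATTENTION_KWARGS_NAMES, so the helper should return that name.
attention_kwargs_name = determine_attention_kwargs_name(StableDiffusionXLPipeline)
assert attention_kwargs_name == "cross_attention_kwargs"

# The tests then build per-call kwargs like this and splat them into the pipeline:
attention_kwargs = {attention_kwargs_name: {"scale": 0.8}}
```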
@@ -442,14 +454,7 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder + scale argument
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -740,12 +745,7 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder + Unet + scale argument
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -878,9 +878,11 @@ class PeftLoraLoaderMixinTests:
            pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
            # unloading should remove the LoRA layers
@@ -1608,26 +1610,21 @@ class PeftLoraLoaderMixinTests:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            # Attach a second adapter
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")

            # set them to multi-adapter inference mode
            pipe.set_adapters(["adapter-1", "adapter-2"])
@@ -1637,6 +1634,7 @@ class PeftLoraLoaderMixinTests:
            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            # Fusing should still keep the LoRA layers so the output should remain the same
            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1647,9 +1645,23 @@ class PeftLoraLoaderMixinTests:
            )

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                    )

            pipe.fuse_lora(
                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
            )
            self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            # Fusing should still keep the LoRA layers
            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1657,6 +1669,63 @@ class PeftLoraLoaderMixinTests:
                np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )
            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

    def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for lora_scale in [1.0, 0.8]:
            for scheduler_cls in self.scheduler_classes:
                components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
                pipe = self.pipeline_class(**components)
                pipe = pipe.to(torch_device)
                pipe.set_progress_bar_config(disable=None)
                _, _, inputs = self.get_dummy_inputs(with_generator=False)

                output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
                self.assertTrue(output_no_lora.shape == self.output_shape)

                if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                    )

                denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
                denoiser.add_adapter(denoiser_lora_config, "adapter-1")
                self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

                if self.has_two_text_encoders or self.has_three_text_encoders:
                    lora_loadable_components = self.pipeline_class._lora_loadable_modules
                    if "text_encoder_2" in lora_loadable_components:
                        pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                        self.assertTrue(
                            check_if_lora_correctly_set(pipe.text_encoder_2),
                            "Lora not correctly set in text encoder 2",
                        )

                pipe.set_adapters(["adapter-1"])
                attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
                outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

                pipe.fuse_lora(
                    components=self.pipeline_class._lora_loadable_modules,
                    adapter_names=["adapter-1"],
                    lora_scale=lora_scale,
                )
                self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

                outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

                self.assertTrue(
                    np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
                    "Fused lora should not change the output",
                )
                self.assertFalse(
                    np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
                    "LoRA should change the output",
                )
    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
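The new `test_lora_scale_kwargs_match_fusion` checks that scaling a LoRA at call time and baking the same scale into the weights via `fuse_lora(..., lora_scale=...)` agree within tolerance. A condensed sketch of that comparison, with a hypothetical `pipe` that already has "adapter-1" loaded and an SDXL-style `cross_attention_kwargs` argument:

```python
import numpy as np
import torch


def assert_scale_matches_fusion(pipe, inputs, lora_scale=0.8, atol=1e-3, rtol=1e-3):
    # Path A: apply the scale at call time through the attention kwargs.
    scaled = pipe(**inputs, generator=torch.manual_seed(0),
                  cross_attention_kwargs={"scale": lora_scale})[0]

    # Path B: bake the same scale into the fused weights.
    pipe.fuse_lora(
        components=pipe._lora_loadable_modules,
        adapter_names=["adapter-1"],
        lora_scale=lora_scale,
    )
    fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

    # The test asserts the two paths agree within tolerance.
    assert np.allclose(scaled, fused, atol=atol, rtol=rtol)
```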
@@ -1838,12 +1907,7 @@ class PeftLoraLoaderMixinTests:
    def test_set_adapters_match_attention_kwargs(self):
        """Test to check if outputs after `set_adapters()` and attention kwargs match."""
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
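For context, `test_set_adapters_match_attention_kwargs` (only its helper call changes in this commit) compares two equivalent ways of applying an adapter weight. A condensed sketch, again with a hypothetical `pipe` that has "adapter-1" loaded and a `cross_attention_kwargs`-style call signature:

```python
import torch


def compare_set_adapters_to_kwargs(pipe, inputs, scale=0.5):
    # Route 1: bake the weight in up front with set_adapters().
    pipe.set_adapters(["adapter-1"], scale)
    out_set = pipe(**inputs, generator=torch.manual_seed(0))[0]

    # Route 2: default weight, same scale passed per call instead.
    pipe.set_adapters(["adapter-1"])
    out_kwargs = pipe(**inputs, generator=torch.manual_seed(0),
                      cross_attention_kwargs={"scale": scale})[0]
    return out_set, out_kwargs  # the test asserts these match
```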