Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[LoRA] improve LoRA fusion tests (#11274)
* improve lora fusion tests
* more improvements.
* remove comment
* update
* relax tolerance.
* num_fused_loras as a property

Co-authored-by: BenjaminBossan <benjamin.bossan@gmail.com>

* updates
* update
* fix
* fix

Co-authored-by: BenjaminBossan <benjamin.bossan@gmail.com>

* Update src/diffusers/loaders/lora_base.py

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

---------

Co-authored-by: BenjaminBossan <benjamin.bossan@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
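The central change described above ("num_fused_loras as a property") derives the fused count from a `_merged_adapters` set instead of manually incrementing and decrementing a counter. Below is a minimal, self-contained sketch of that bookkeeping pattern, simplified for illustration and not the actual `LoraBaseMixin` code:

class FusionTracker:
    """Toy illustration of set-based fusion bookkeeping (not the diffusers implementation)."""

    def __init__(self):
        self._merged_adapters = set()

    def fuse(self, adapter_names):
        # Record exactly which adapters were merged; the property below then
        # cannot drift out of sync the way a manual += 1 / -= 1 counter can.
        self._merged_adapters |= set(adapter_names)

    def unfuse(self, adapter_names):
        self._merged_adapters -= set(adapter_names)

    @property
    def num_fused_loras(self):
        # Count is always derived from the recorded set.
        return len(self._merged_adapters)

    @property
    def fused_loras(self):
        return self._merged_adapters


tracker = FusionTracker()
tracker.fuse(["adapter-1", "adapter-2"])
assert tracker.num_fused_loras == 2
tracker.unfuse(["adapter-1"])
assert tracker.fused_loras == {"adapter-2"}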
@@ -465,7 +465,7 @@ class LoraBaseMixin:
    """Utility class for handling LoRAs."""

    _lora_loadable_modules = []
    num_fused_loras = 0
    _merged_adapters = set()

    def load_lora_weights(self, **kwargs):
        raise NotImplementedError("`load_lora_weights()` is not implemented.")
@@ -592,6 +592,9 @@ class LoraBaseMixin:
        if len(components) == 0:
            raise ValueError("`components` cannot be an empty list.")

        # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
        # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
        merged_adapter_names = set()
        for fuse_component in components:
            if fuse_component not in self._lora_loadable_modules:
                raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -601,13 +604,19 @@ class LoraBaseMixin:
            # check if diffusers model
            if issubclass(model.__class__, ModelMixin):
                model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
                for module in model.modules():
                    if isinstance(module, BaseTunerLayer):
                        merged_adapter_names.update(set(module.merged_adapters))
            # handle transformers models.
            if issubclass(model.__class__, PreTrainedModel):
                fuse_text_encoder_lora(
                    model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
                )
                for module in model.modules():
                    if isinstance(module, BaseTunerLayer):
                        merged_adapter_names.update(set(module.merged_adapters))

        self.num_fused_loras += 1
        self._merged_adapters = self._merged_adapters | merged_adapter_names

    def unfuse_lora(self, components: List[str] = [], **kwargs):
        r"""
@@ -661,9 +670,18 @@ class LoraBaseMixin:
                if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
                    for module in model.modules():
                        if isinstance(module, BaseTunerLayer):
                            for adapter in set(module.merged_adapters):
                                if adapter and adapter in self._merged_adapters:
                                    self._merged_adapters = self._merged_adapters - {adapter}
                            module.unmerge()

        self.num_fused_loras -= 1
    @property
    def num_fused_loras(self):
        return len(self._merged_adapters)

    @property
    def fused_loras(self):
        return self._merged_adapters

    def set_adapters(
        self,
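For orientation, the new properties are meant to be read off a loaded pipeline roughly as in the hedged sketch below. The base checkpoint is a public SDXL model, but the LoRA repository ids and adapter names are placeholders invented for this example; the exact call pattern should be checked against the released diffusers API rather than this sketch.

# Hedged usage sketch; the LoRA repo ids below are placeholders, not from this diff.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

pipe.load_lora_weights("user/lora-one", adapter_name="adapter-1")  # placeholder repo id
pipe.load_lora_weights("user/lora-two", adapter_name="adapter-2")  # placeholder repo id

# Fuse only one adapter; the properties report what was actually merged.
pipe.fuse_lora(components=pipe._lora_loadable_modules, adapter_names=["adapter-1"])
print(pipe.num_fused_loras)  # expected: 1
print(pipe.fused_loras)      # expected: {"adapter-1"}

pipe.unfuse_lora(components=pipe._lora_loadable_modules)
print(pipe.num_fused_loras)  # expected: 0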
@@ -124,6 +124,9 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

    def test_lora_scale_kwargs_match_fusion(self):
        super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)

    @unittest.skip("Not supported in CogVideoX.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass
@@ -117,6 +117,40 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
    def test_multiple_wrong_adapter_name_raises_error(self):
        super().test_multiple_wrong_adapter_name_raises_error()

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        if torch.cuda.is_available():
            expected_atol = 9e-2
            expected_rtol = 9e-2
        else:
            expected_atol = 1e-3
            expected_rtol = 1e-3

        super().test_simple_inference_with_text_denoiser_lora_unfused(
            expected_atol=expected_atol, expected_rtol=expected_rtol
        )

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        if torch.cuda.is_available():
            expected_atol = 9e-2
            expected_rtol = 9e-2
        else:
            expected_atol = 1e-3
            expected_rtol = 1e-3

        super().test_simple_inference_with_text_lora_denoiser_fused_multi(
            expected_atol=expected_atol, expected_rtol=expected_rtol
        )

    def test_lora_scale_kwargs_match_fusion(self):
        if torch.cuda.is_available():
            expected_atol = 9e-2
            expected_rtol = 9e-2
        else:
            expected_atol = 1e-3
            expected_rtol = 1e-3

        super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)


@slow
@nightly
@@ -80,6 +80,18 @@ def initialize_dummy_state_dict(state_dict):
POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


def determine_attention_kwargs_name(pipeline_class):
    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()

    # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
    for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
        if possible_attention_kwargs in call_signature_keys:
            attention_kwargs_name = possible_attention_kwargs
            break
    assert attention_kwargs_name is not None
    return attention_kwargs_name


@require_peft_backend
class PeftLoraLoaderMixinTests:
    pipeline_class = None
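The `determine_attention_kwargs_name` helper factors out the lookup loop that each test previously repeated inline (visible in the hunks below, where the old inline blocks sit next to the new one-line call). Here is a hedged, standalone illustration of the same lookup pattern; `pick_attention_kwargs_name` and `DummyPipeline` are invented for this example and are not the library's own names:

import inspect

POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


def pick_attention_kwargs_name(pipeline_class):
    # Return whichever attention-kwargs argument the pipeline's __call__ accepts.
    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
    for name in POSSIBLE_ATTENTION_KWARGS_NAMES:
        if name in call_signature_keys:
            return name
    raise ValueError("No known attention kwargs argument found.")


class DummyPipeline:  # stand-in for a real pipeline class, invented for this example
    def __call__(self, prompt, joint_attention_kwargs=None):
        return prompt


assert pick_attention_kwargs_name(DummyPipeline) == "joint_attention_kwargs"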
@@ -442,14 +454,7 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder + scale argument
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()

        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -740,12 +745,7 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached on the text encoder + Unet + scale argument
        and makes sure it works as expected
        """
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -878,9 +878,11 @@ class PeftLoraLoaderMixinTests:
            pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

            # unloading should remove the LoRA layers
@@ -1608,26 +1610,21 @@ class PeftLoraLoaderMixinTests:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                )
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
            denoiser.add_adapter(denoiser_lora_config, "adapter-1")

            # Attach a second adapter
            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                lora_loadable_components = self.pipeline_class._lora_loadable_modules
                if "text_encoder_2" in lora_loadable_components:
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                    )
                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")

            # set them to multi-adapter inference mode
            pipe.set_adapters(["adapter-1", "adapter-2"])
@@ -1637,6 +1634,7 @@ class PeftLoraLoaderMixinTests:
            outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            # Fusing should still keep the LoRA layers so outpout should remain the same
            outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1647,9 +1645,23 @@ class PeftLoraLoaderMixinTests:
            )

            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

            if self.has_two_text_encoders or self.has_three_text_encoders:
                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                    )

            pipe.fuse_lora(
                components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
            )
            self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

            # Fusing should still keep the LoRA layers
            output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1657,6 +1669,63 @@ class PeftLoraLoaderMixinTests:
                np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
                "Fused lora should not change the output",
            )
            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

    def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for lora_scale in [1.0, 0.8]:
            for scheduler_cls in self.scheduler_classes:
                components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
                pipe = self.pipeline_class(**components)
                pipe = pipe.to(torch_device)
                pipe.set_progress_bar_config(disable=None)
                _, _, inputs = self.get_dummy_inputs(with_generator=False)

                output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
                self.assertTrue(output_no_lora.shape == self.output_shape)

                if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                    pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
                    self.assertTrue(
                        check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                    )

                denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
                denoiser.add_adapter(denoiser_lora_config, "adapter-1")
                self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

                if self.has_two_text_encoders or self.has_three_text_encoders:
                    lora_loadable_components = self.pipeline_class._lora_loadable_modules
                    if "text_encoder_2" in lora_loadable_components:
                        pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                        self.assertTrue(
                            check_if_lora_correctly_set(pipe.text_encoder_2),
                            "Lora not correctly set in text encoder 2",
                        )

                pipe.set_adapters(["adapter-1"])
                attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
                outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

                pipe.fuse_lora(
                    components=self.pipeline_class._lora_loadable_modules,
                    adapter_names=["adapter-1"],
                    lora_scale=lora_scale,
                )
                self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

                outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

                self.assertTrue(
                    np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
                    "Fused lora should not change the output",
                )
                self.assertFalse(
                    np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
                    "LoRA should change the output",
                )

    @require_peft_version_greater(peft_version="0.9.0")
    def test_simple_inference_with_dora(self):
@@ -1838,12 +1907,7 @@ class PeftLoraLoaderMixinTests:
    def test_set_adapters_match_attention_kwargs(self):
        """Test to check if outputs after `set_adapters()` and attention kwargs match."""
        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
            if possible_attention_kwargs in call_signature_keys:
                attention_kwargs_name = possible_attention_kwargs
                break
        assert attention_kwargs_name is not None
        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

        for scheduler_cls in self.scheduler_classes:
            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)