Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[LoRA] add a test to ensure set_adapters() and attn kwargs outs match (#10110)
* add a test to ensure set_adapters() and attn kwargs outs match
* remove print
* fix
* Apply suggestions from code review
* assertFalse.

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
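For context, a minimal sketch of the equivalence the new test exercises: scaling a loaded LoRA at call time through the pipeline's attention kwargs should produce the same result as baking the scale in once with set_adapters(). The snippet below is illustrative only and is not part of this commit; the model id, LoRA checkpoint path, and adapter name are placeholder assumptions.

    # Illustrative sketch (not from this commit); the model id, LoRA path, and the
    # adapter name "default_0" are assumptions, not taken from the diff below.
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("path/to/lora")  # placeholder checkpoint

    prompt = "a photo of a cat"

    # Route 1: pass the LoRA scale at call time via the attention kwargs argument
    # (cross_attention_kwargs for UNet pipelines; joint_attention_kwargs or
    # attention_kwargs for transformer-based pipelines).
    out_kwargs = pipe(
        prompt, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
    ).images[0]

    # Route 2: set the same scale once with set_adapters(), then call without kwargs.
    pipe.set_adapters("default_0", 0.5)
    out_set_adapters = pipe(prompt, generator=torch.manual_seed(0)).images[0]

    # The new test below asserts that these two routes give numerically matching
    # outputs (np.allclose with atol/rtol of 1e-3), and that the match survives a
    # save_lora_weights() / load_lora_weights() round trip.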
@@ -76,6 +76,9 @@ def initialize_dummy_state_dict(state_dict):
     return {k: torch.randn(v.shape, device=torch_device, dtype=v.dtype) for k, v in state_dict.items()}
 
 
+POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]
+
+
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     pipeline_class = None
@@ -429,7 +432,7 @@ class PeftLoraLoaderMixinTests:
         call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
 
         # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
-        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
+        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
             if possible_attention_kwargs in call_signature_keys:
                 attention_kwargs_name = possible_attention_kwargs
                 break
@@ -790,7 +793,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]:
+        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
             if possible_attention_kwargs in call_signature_keys:
                 attention_kwargs_name = possible_attention_kwargs
                 break
@@ -1885,3 +1888,88 @@ class PeftLoraLoaderMixinTests:
 
             _, _, inputs = self.get_dummy_inputs()
             _ = pipe(**inputs)[0]
+
+    def test_set_adapters_match_attention_kwargs(self):
+        """Test to check if outputs after `set_adapters()` and attention kwargs match."""
+        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
+        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
+            if possible_attention_kwargs in call_signature_keys:
+                attention_kwargs_name = possible_attention_kwargs
+                break
+        assert attention_kwargs_name is not None
+
+        for scheduler_cls in self.scheduler_classes:
+            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(output_no_lora.shape == self.output_shape)
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config)
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder_2.add_adapter(text_lora_config)
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                    )
+
+            lora_scale = 0.5
+            attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+            output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+            self.assertFalse(
+                np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+                "Lora + scale should change the output",
+            )
+
+            pipe.set_adapters("default", lora_scale)
+            output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.assertTrue(
+                not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+                "Lora + scale should change the output",
+            )
+            self.assertTrue(
+                np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+                "Lora + scale should match the output of `set_adapters()`.",
+            )
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+                lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+                self.pipeline_class.save_lora_weights(
+                    save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+                )
+
+                self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+                pipe = self.pipeline_class(**components)
+                pipe = pipe.to(torch_device)
+                pipe.set_progress_bar_config(disable=None)
+                pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+
+                for module_name, module in modules_to_save.items():
+                    self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+
+                output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+                self.assertTrue(
+                    not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+                    "Lora + scale should change the output",
+                )
+                self.assertTrue(
+                    np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+                    "Loading from saved checkpoints should give same results as attention_kwargs.",
+                )
+                self.assertTrue(
+                    np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+                    "Loading from saved checkpoints should give same results as set_adapters().",
+                )
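As a side note on how this mixin is consumed: pipeline-specific LoRA test suites subclass PeftLoraLoaderMixinTests and fill in class attributes such as pipeline_class, so the new test discovers the right attention-kwargs name from the pipeline's __call__ signature. A rough sketch follows; the class name and attribute set shown are illustrative and not part of this commit.

    # Rough sketch of a concrete suite opting into the mixin; the class name and
    # attributes here are illustrative, not taken from this commit.
    import unittest

    from diffusers import StableDiffusionPipeline


    class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
        pipeline_class = StableDiffusionPipeline
        # Concrete suites also define scheduler_classes, unet_kwargs, output_shape, etc.;
        # test_set_adapters_match_attention_kwargs then picks up "cross_attention_kwargs"
        # from StableDiffusionPipeline.__call__ automatically.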