1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] depcrecate save_attn_procs(). (#10126)

depcrecate save_attn_procs().
This commit is contained in:
Sayak Paul
2024-12-07 02:08:57 +05:30
committed by GitHub
parent 188bca3084
commit fa3a9100be
2 changed files with 21 additions and 0 deletions

View File

@@ -492,6 +492,9 @@ class UNet2DConditionLoadersMixin:
)
state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
else:
deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
deprecate("save_attn_procs", "0.40.0", deprecation_message)
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")

View File

@@ -1119,6 +1119,24 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
), "Loading from a saved checkpoint should produce identical results."
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with tempfile.TemporaryDirectory() as tmpdirname:
with self.assertWarns(FutureWarning) as warning:
model.save_attn_procs(tmpdirname)
warning_message = str(warning.warnings[0].message)
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):