mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[LoRA] depcrecate save_attn_procs(). (#10126)
depcrecate save_attn_procs().
This commit is contained in:
@@ -492,6 +492,9 @@ class UNet2DConditionLoadersMixin:
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
|
||||
else:
|
||||
deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
|
||||
deprecate("save_attn_procs", "0.40.0", deprecation_message)
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
|
||||
|
||||
|
||||
@@ -1119,6 +1119,24 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
|
||||
), "Loading from a saved checkpoint should produce identical results."
|
||||
|
||||
@require_peft_backend
|
||||
def test_save_attn_procs_raise_warning(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
|
||||
|
||||
|
||||
@slow
|
||||
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user