1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Files
diffusers/tests/lora/test_deprecated_utilities.py
Sayak Paul 13e8fdecda [feat] add load_lora_adapter() for compatible models (#9712)
* add first draft.

* fix

* updates.

* updates.

* updates

* updates

* updates.

* fix-copies

* lora constants.

* add tests

* Apply suggestions from code review

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* docstrings.

---------

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
2024-11-02 09:50:39 +05:30

40 lines
1.5 KiB
Python

import os
import tempfile
import unittest
import torch
from diffusers.loaders.lora_base import LoraBaseMixin
class UtilityMethodDeprecationTests(unittest.TestCase):
    """Check that deprecated utility class methods on `LoraBaseMixin` emit a `FutureWarning`."""

    def _assert_future_warning_mentions(self, caught, fragment):
        """Assert that at least one captured `FutureWarning` message contains `fragment`.

        Args:
            caught: The context manager object bound by `with self.assertWarns(...) as caught`,
                inspected after its `with` block has exited.
            fragment: Text expected to appear in the message of one of the captured
                `FutureWarning` records.
        """
        # Look at every recorded FutureWarning instead of only `warnings[0]`: an unrelated
        # warning recorded earlier must not hide (or be mistaken for) the deprecation message.
        messages = [
            str(record.message) for record in caught.warnings if issubclass(record.category, FutureWarning)
        ]
        # Use a unittest assertion rather than a bare `assert`, which is stripped under
        # `python -O` and reports nothing useful on failure.
        self.assertTrue(
            any(fragment in message for message in messages),
            msg=f"No FutureWarning mentioning {fragment!r} was captured; got: {messages}",
        )

    def test_fetch_state_dict_cls_method_raises_warning(self):
        """Calling `LoraBaseMixin._fetch_state_dict()` with an in-memory state dict must warn."""
        state_dict = torch.nn.Linear(3, 3).state_dict()
        with self.assertWarns(FutureWarning) as caught:
            LoraBaseMixin._fetch_state_dict(
                state_dict,
                weight_name=None,
                use_safetensors=False,
                local_files_only=True,
                cache_dir=None,
                force_download=False,
                proxies=None,
                token=None,
                revision=None,
                subfolder=None,
                user_agent=None,
                allow_pickle=None,
            )
        self._assert_future_warning_mentions(caught, "Using the `_fetch_state_dict()` method from")

    def test_best_guess_weight_name_cls_method_raises_warning(self):
        """Calling `LoraBaseMixin._best_guess_weight_name()` on a directory of weights must warn."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Write a small weights file into the temporary directory that is handed to the method.
            state_dict = torch.nn.Linear(3, 3).state_dict()
            torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))

            with self.assertWarns(FutureWarning) as caught:
                LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)
            self._assert_future_warning_mentions(caught, "Using the `_best_guess_weight_name()` method from")