mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
* add first draft. * fix * updates. * updates. * updates * updates * updates. * fix-copies * lora constants. * add tests * Apply suggestions from code review Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> * docstrings. --------- Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
40 lines
1.5 KiB
Python
40 lines
1.5 KiB
Python
import os
|
|
import tempfile
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from diffusers.loaders.lora_base import LoraBaseMixin
|
|
|
|
|
|
class UtilityMethodDeprecationTests(unittest.TestCase):
    """Check that calling LoRA utility class methods on ``LoraBaseMixin`` emits a ``FutureWarning``.

    Each test invokes one private class method directly on ``LoraBaseMixin`` and
    asserts that a ``FutureWarning`` mentioning that method was raised.
    """

    def _assert_future_warning_contains(self, context, expected_substring):
        """Assert that a ``FutureWarning`` recorded by ``context`` contains ``expected_substring``.

        Args:
            context: The object returned by ``self.assertWarns(...)`` used as a
                context manager.
            expected_substring: Text that at least one recorded ``FutureWarning``
                message must contain.

        Raises:
            AssertionError: If no recorded ``FutureWarning`` message contains
                ``expected_substring``.
        """
        # `context.warnings` holds every warning recorded inside the `with`
        # block, regardless of category, so index 0 is not guaranteed to be the
        # deprecation warning under test. Filter by category and search all.
        future_messages = [
            str(record.message)
            for record in context.warnings
            if issubclass(record.category, FutureWarning)
        ]
        self.assertTrue(
            any(expected_substring in message for message in future_messages),
            f"No FutureWarning containing {expected_substring!r}; got: {future_messages!r}",
        )

    def test_fetch_state_dict_cls_method_raises_warning(self):
        """``LoraBaseMixin._fetch_state_dict`` warns when called via the mixin."""
        state_dict = torch.nn.Linear(3, 3).state_dict()

        # The state dict is passed in directly; the remaining arguments are
        # neutral placeholders (local_files_only=True presumably keeps any hub
        # access disabled -- confirm against LoraBaseMixin).
        with self.assertWarns(FutureWarning) as context:
            LoraBaseMixin._fetch_state_dict(
                state_dict,
                weight_name=None,
                use_safetensors=False,
                local_files_only=True,
                cache_dir=None,
                force_download=False,
                proxies=None,
                token=None,
                revision=None,
                subfolder=None,
                user_agent=None,
                allow_pickle=None,
            )

        self._assert_future_warning_contains(context, "Using the `_fetch_state_dict()` method from")

    def test_best_guess_weight_name_cls_method_raises_warning(self):
        """``LoraBaseMixin._best_guess_weight_name`` warns when called via the mixin."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save a weights file into the directory so the lookup has a file
            # to consider (assumes the method inspects the directory contents).
            state_dict = torch.nn.Linear(3, 3).state_dict()
            torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))

            with self.assertWarns(FutureWarning) as context:
                LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)

            self._assert_future_warning_contains(context, "Using the `_best_guess_weight_name()` method from")
|