1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] relax lora loading logic (#4610)

* relax lora loading logic.

* cater to the other cases too.

* fix: variable name

* bring the chaos down.

* check

* deal with checkpointed files.

* Apply suggestions from code review

Co-authored-by: apolinário <joaopaulo.passos@gmail.com>

* style

---------

Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
This commit is contained in:
Sayak Paul
2023-08-25 09:35:51 +05:30
committed by GitHub
parent c25c46137d
commit 5222294748

View File

@@ -25,7 +25,7 @@ import requests
import safetensors
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from huggingface_hub import hf_hub_download, model_info
from torch import nn
from .utils import (
@@ -1021,6 +1021,13 @@ class LoraLoaderMixin:
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
# Here we're relaxing the loading check to enable more Inference API
# friendliness where sometimes, it's not at all possible to automatically
# determine `weight_name`.
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
@@ -1041,7 +1048,12 @@ class LoraLoaderMixin:
# try loading non-safetensors weights
model_file = None
pass
if model_file is None:
if weight_name is None:
weight_name = cls._best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin"
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME,
@@ -1077,6 +1089,31 @@ class LoraLoaderMixin:
return state_dict, network_alphas
@classmethod
def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
targeted_files = []
if os.path.isfile(pretrained_model_name_or_path_or_dict):
return
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
targeted_files = [
f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
]
else:
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
if len(targeted_files) == 0:
return
targeted_files = list(filter(lambda x: "scheduler" not in x and "optimizer" not in x, targeted_files))
if len(targeted_files) > 1:
raise ValueError(
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
)
weight_name = targeted_files[0]
return weight_name
@classmethod
def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
is_all_unet = all(k.startswith("lora_unet") for k in state_dict)