[LoRA] relax lora loading logic (#4610)
* relax lora loading logic.
* cater to the other cases too.
* fix: variable name
* bring the chaos down.
* check
* deal with checkpointed files.
* Apply suggestions from code review
* style

Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
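In practical terms, the relaxed logic means `weight_name` can now be omitted from `load_lora_weights` when the LoRA repository or local folder contains a single weights file; the loader guesses the name itself. A minimal usage sketch under that assumption (the LoRA repo id below is a placeholder, not a real checkpoint):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )

    # Before this change, something like weight_name="pytorch_lora_weights.safetensors"
    # had to be passed explicitly; a repo or folder holding exactly one weights file
    # can now be loaded without it.
    pipe.load_lora_weights("some-user/some-lora")  # placeholder repo id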
@@ -25,7 +25,7 @@ import requests
 import safetensors
 import torch
 import torch.nn.functional as F
-from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download, model_info
 from torch import nn
 
 from .utils import (
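The newly imported `model_info` utility is what lets the loader inspect a Hub repository's file listing without downloading anything. A small illustration of that huggingface_hub call (any public repo id works; the one below is only an example):

    from huggingface_hub import model_info

    # Each "sibling" entry describes one file in the repo and exposes its `rfilename`.
    info = model_info("runwayml/stable-diffusion-v1-5")
    safetensors_files = [s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")]
    print(safetensors_files)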
@@ -1021,6 +1021,13 @@ class LoraLoaderMixin:
                 weight_name is not None and weight_name.endswith(".safetensors")
             ):
                 try:
+                    # Here we're relaxing the loading check to enable more Inference API
+                    # friendliness where sometimes, it's not at all possible to automatically
+                    # determine `weight_name`.
+                    if weight_name is None:
+                        weight_name = cls._best_guess_weight_name(
+                            pretrained_model_name_or_path_or_dict, file_extension=".safetensors"
+                        )
                     model_file = _get_model_file(
                         pretrained_model_name_or_path_or_dict,
                         weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
@@ -1041,7 +1048,12 @@ class LoraLoaderMixin:
                     # try loading non-safetensors weights
                     model_file = None
                     pass
+
             if model_file is None:
+                if weight_name is None:
+                    weight_name = cls._best_guess_weight_name(
+                        pretrained_model_name_or_path_or_dict, file_extension=".bin"
+                    )
                 model_file = _get_model_file(
                     pretrained_model_name_or_path_or_dict,
                     weights_name=weight_name or LORA_WEIGHT_NAME,
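Taken together, the two hunks above give the loader a two-stage resolution order: look for a `.safetensors` file first (guessing its name when none was given), and only fall back to a `.bin` file when pickle loading is allowed. A condensed, standalone sketch of that control flow for a local folder, using the library's default file names but otherwise simplified (this is not the actual `lora_state_dict` code path):

    import os

    def pick_lora_file(folder, weight_name=None, allow_pickle=True):
        # Prefer safetensors; fall back to .bin only when pickle loading is allowed.
        for ext, default in ((".safetensors", "pytorch_lora_weights.safetensors"),
                             (".bin", "pytorch_lora_weights.bin")):
            if ext == ".bin" and not allow_pickle:
                break
            guessed = weight_name or next(
                (f for f in os.listdir(folder) if f.endswith(ext)), None
            )
            candidate = os.path.join(folder, guessed or default)
            if os.path.isfile(candidate):
                return candidate
        raise FileNotFoundError(f"No LoRA weights found in {folder}.")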
@@ -1077,6 +1089,31 @@ class LoraLoaderMixin:
 
         return state_dict, network_alphas
 
+    @classmethod
+    def _best_guess_weight_name(cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors"):
+        targeted_files = []
+
+        if os.path.isfile(pretrained_model_name_or_path_or_dict):
+            return
+        elif os.path.isdir(pretrained_model_name_or_path_or_dict):
+            targeted_files = [
+                f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
+            ]
+        else:
+            files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
+            targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
+
+        if len(targeted_files) == 0:
+            return
+
+        targeted_files = list(filter(lambda x: "scheduler" not in x and "optimizer" not in x, targeted_files))
+        if len(targeted_files) > 1:
+            raise ValueError(
+                f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
+            )
+        weight_name = targeted_files[0]
+        return weight_name
+
     @classmethod
     def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
         is_all_unet = all(k.startswith("lora_unet") for k in state_dict)
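The `scheduler`/`optimizer` filter in `_best_guess_weight_name` is what keeps a training-checkpoint folder ("deal with checkpointed files" in the commit message) from tripping the "more than one weights file" error. A small self-contained illustration of that filtering on a throwaway directory (the file names are invented for the example):

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as folder:
        # A typical checkpoint layout: one LoRA weights file plus training-state files.
        for name in ("pytorch_lora_weights.safetensors", "optimizer.safetensors", "scheduler.safetensors"):
            open(os.path.join(folder, name), "w").close()

        candidates = [f for f in os.listdir(folder) if f.endswith(".safetensors")]
        # Scheduler/optimizer states are dropped, mirroring the helper's filter;
        # exactly one file remains, so it becomes the guessed weight_name.
        candidates = [f for f in candidates if "scheduler" not in f and "optimizer" not in f]
        assert candidates == ["pytorch_lora_weights.safetensors"]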