import io
import base64
import os
import re
import time
import json
import collections
from PIL import Image
from modules import shared, paths, modelloader, hashes, sd_hijack_accelerate

checkpoints_list = {}
checkpoint_aliases = {}
checkpoints_loaded = collections.OrderedDict()
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
sd_metadata_file = os.path.join(paths.data_path, "metadata.json")
sd_metadata = None
sd_metadata_pending = 0
sd_metadata_timer = 0
warn_once = False


class CheckpointInfo:
    def __init__(self, filename, sha=None, subfolder=None):
        self.name = None
        self.hash = sha
        self.filename = filename
        self.type = ''
        self.subfolder = subfolder
        relname = filename
        app_path = os.path.abspath(paths.script_path)

        def rel(fn, path):
            try:
                return os.path.relpath(fn, path)
            except Exception:
                return fn

        # normalize the display name to a path relative to whichever known model folder contains the file
        if relname.startswith('..'):
            relname = os.path.abspath(relname)
        if relname.startswith(shared.opts.ckpt_dir):
            relname = rel(filename, shared.opts.ckpt_dir)
        elif relname.startswith(shared.opts.diffusers_dir):
            relname = rel(filename, shared.opts.diffusers_dir)
        elif relname.startswith(model_path):
            relname = rel(filename, model_path)
        elif relname.startswith(paths.script_path):
            relname = rel(filename, paths.script_path)
        elif relname.startswith(app_path):
            relname = rel(filename, app_path)
        else:
            relname = os.path.abspath(relname)
        relname, ext = os.path.splitext(relname)
        ext = ext.lower()[1:]

        if filename.lower() == 'none':
            self.name = 'none'
            self.relname = 'none'
            self.sha256 = None
            self.type = 'unknown'
        elif os.path.isfile(filename):  # ckpt or safetensors file
            self.name = relname
            self.filename = filename
            self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{relname}")
            self.type = ext
            if 'nf4' in filename:
                self.type = 'transformer'
        else:  # maybe a diffusers folder
            if self.hash is None:
                repo = [r for r in modelloader.diffuser_repos if self.filename == r['name']]
            else:
                repo = [r for r in modelloader.diffuser_repos if self.hash == r['hash']]
            if len(repo) == 0:
                self.name = filename
                self.filename = filename
                self.sha256 = None
                self.type = 'unknown'
            else:
                self.name = os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]['name'])
                self.filename = repo[0]['path']
                self.sha256 = repo[0]['hash']
                self.type = 'diffusers'

        self.shorthash = self.sha256[0:10] if self.sha256 else None
        self.title = self.name if self.shorthash is None else f'{self.name} [{self.shorthash}]'
        self.path = self.filename
        self.model_name = os.path.basename(self.name)
        self.metadata = read_metadata_from_safetensors(filename)
        # shared.log.debug(f'Checkpoint: type={self.type} name={self.name} filename={self.filename} hash={self.shorthash} title={self.title}')

    def register(self):
        checkpoints_list[self.title] = self
        for i in [self.name, self.filename, self.shorthash, self.title]:
            if i is not None:
                checkpoint_aliases[i] = self

    def calculate_shorthash(self):
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return None
        self.shorthash = self.sha256[0:10]
        if self.title in checkpoints_list:
            checkpoints_list.pop(self.title)
        self.title = f'{self.name} [{self.shorthash}]'
        self.register()
        return self.shorthash

    def __str__(self):
        return f'CheckpointInfo(name="{self.name}" filename="{self.filename}" hash={self.shorthash} type={self.type} title="{self.title}" path="{self.path}" subfolder="{self.subfolder}")'

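
# Usage sketch (illustrative only; the path and hash shown are hypothetical):
#   info = CheckpointInfo('/models/Stable-diffusion/my-model.safetensors')
#   info.register()             # makes it resolvable via checkpoint_aliases
#   info.calculate_shorthash()  # computes the sha256 and updates the title
#   print(info)                 # CheckpointInfo(name="my-model" ... title="my-model [0123456789]")

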
def setup_model():
    list_models()
    sd_hijack_accelerate.hijack_hfhub()
    # sd_hijack_accelerate.hijack_torch_conv()


def checkpoint_titles(use_short=False):
    def convert(name):
        return int(name) if name.isdigit() else name.lower()

    def alphanumeric_key(key):
        return [convert(c) for c in re.split("([0-9]+)", key)]

    if use_short:
        return sorted([x.title.rsplit("\\", 1)[-1].rsplit("/", 1)[-1] for x in checkpoints_list.values()], key=alphanumeric_key)
    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)

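
# Sorting note (illustrative): alphanumeric_key produces a natural sort, so
# 'model2' orders before 'model10', unlike plain lexicographic string order
# where 'model10' < 'model2'.

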
def list_models():
    t0 = time.time()
    global checkpoints_list  # pylint: disable=global-statement
    checkpoints_list.clear()
    checkpoint_aliases.clear()
    ext_filter = [".safetensors"]
    model_list = list(modelloader.load_models(model_path=model_path, model_url=None, command_path=shared.opts.ckpt_dir, ext_filter=ext_filter, download_name=None, ext_blacklist=[".vae.ckpt", ".vae.safetensors"]))
    safetensors_list = []
    for filename in sorted(model_list, key=str.lower):
        checkpoint_info = CheckpointInfo(filename)
        safetensors_list.append(checkpoint_info)
        if checkpoint_info.name is not None:
            checkpoint_info.register()
    diffusers_list = []
    for repo in modelloader.load_diffusers_models(clear=True):
        checkpoint_info = CheckpointInfo(repo['name'], sha=repo['hash'])
        diffusers_list.append(checkpoint_info)
        if checkpoint_info.name is not None:
            checkpoint_info.register()
    if shared.cmd_opts.ckpt is not None:
        checkpoint_info = CheckpointInfo(shared.cmd_opts.ckpt)
        if checkpoint_info.name is not None:
            checkpoint_info.register()
            shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif shared.cmd_opts.ckpt != shared.default_sd_model_file and shared.cmd_opts.ckpt is not None:
        shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found')
    shared.log.info(f'Available Models: safetensors="{shared.opts.ckpt_dir}":{len(safetensors_list)} diffusers="{shared.opts.diffusers_dir}":{len(diffusers_list)} reference={len(list(shared.reference_models))} items={len(checkpoints_list)} time={time.time()-t0:.2f}')
    checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename))

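
# After list_models() (illustrative recap): every discovered model is reachable
# through checkpoint_aliases by name, filename, shorthash, or title, and
# checkpoints_list maps title -> CheckpointInfo, sorted by filename.

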
def update_model_hashes():
    def update_model_hashes_table(rows):
        html = """
            <table class="simple-table">
                <thead>
                    <tr><th>Name</th><th>Type</th><th>Hash</th></tr>
                </thead>
                <tbody>
                    {tbody}
                </tbody>
            </table>
        """
        tbody = ''
        for row in rows:
            try:
                tbody += f"""
                    <tr>
                        <td>{row.name}</td>
                        <td>{row.type}</td>
                        <td>{row.shorthash}</td>
                    </tr>
                """
            except Exception as e:
                shared.log.error(f'Model list: row={row} {e}')
        return html.format(tbody=tbody)

    lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
    for ckpt in lst:
        ckpt.hash = model_hash(ckpt.filename)
    lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
    shared.log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
    updated = []
    for ckpt in lst:
        ckpt.sha256 = hashes.sha256(ckpt.filename, f"checkpoint/{ckpt.name}")
        ckpt.shorthash = ckpt.sha256[0:10] if ckpt.sha256 is not None else None
        updated.append(ckpt)
        yield update_model_hashes_table(updated)

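
# Usage sketch (illustrative): update_model_hashes() is a generator, so a UI
# callback can stream progressive table updates as each hash is computed, e.g.
#   for html in update_model_hashes():
#       ...  # render the partial table

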
def remove_hash(s):
    return re.sub(r'\s*\[.*?\]', '', s)

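
# Example (illustrative): remove_hash('sdxl-base [1a2b3c4d5e]') -> 'sdxl-base'

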
def get_closest_checkpoint_match(s: str) -> CheckpointInfo:
    # direct hf url
    if s.startswith('https://huggingface.co/'):
        model_name = s.replace('https://huggingface.co/', '')
        checkpoint_info = CheckpointInfo(model_name)  # create a virtual model info
        checkpoint_info.type = 'huggingface'
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=huggingface')
        return checkpoint_info
    if s.startswith('huggingface/'):
        model_name = s.replace('huggingface/', '')
        checkpoint_info = CheckpointInfo(model_name)  # create a virtual model info
        checkpoint_info.type = 'huggingface'
        return checkpoint_info

    # alias search
    checkpoint_info = checkpoint_aliases.get(s, None)
    if checkpoint_info is not None:
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=alias')
        return checkpoint_info

    # models search
    found = sorted([info for info in checkpoints_list.values() if os.path.basename(info.title).lower() == s.lower()], key=lambda x: len(x.title))
    if found and len(found) == 1:
        checkpoint_info = found[0]
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=hash')
        return checkpoint_info

    # nohash search
    found = sorted([info for info in checkpoints_list.values() if remove_hash(info.title).lower() == remove_hash(s).lower()], key=lambda x: len(x.title))
    if found and len(found) == 1:
        checkpoint_info = found[0]
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=model')
        return checkpoint_info

    # absolute path
    if s.endswith('.safetensors') and os.path.isfile(s):
        checkpoint_info = CheckpointInfo(s)
        checkpoint_info.type = 'safetensors'
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=safetensors')
        return checkpoint_info

    # reference search
    ref = [(k, v) for k, v in shared.reference_models.items() if f"{v.get('path', '')}+{v.get('subfolder', '')}" == s]
    if len(ref) == 0:
        ref = [(k, v) for k, v in shared.reference_models.items() if v.get('path', '') == s]
    if ref and len(ref) > 0:
        _name, info = ref[0]
        checkpoint_info = CheckpointInfo(s)
        checkpoint_info.subfolder = info.get('subfolder', None)
        checkpoint_info.type = 'reference'
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=reference')
        return checkpoint_info

    # huggingface search
    if shared.opts.sd_checkpoint_autodownload and (s.count('/') == 1 or s.count('/') == 2):
        if s.count('/') == 2:
            subfolder = '/'.join(s.split('/')[2:])  # subfolder
            s = '/'.join(s.split('/')[:2])  # only user/model
        else:
            subfolder = None
        modelloader.hf_login()
        found = modelloader.find_diffuser(s, full=True)
        if found is None:
            return None
        found = [f for f in found if f == s]
        shared.log.info(f'HF search: model="{s}" results={found}')
        if found is not None and len(found) == 1:
            checkpoint_info = CheckpointInfo(s)
            checkpoint_info.type = 'huggingface'
            if subfolder is not None and len(subfolder) > 0:
                checkpoint_info.subfolder = subfolder
            shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=huggingface')
            return checkpoint_info

    # civitai search
    if shared.opts.sd_checkpoint_autodownload and s.startswith("https://civitai.com/api/download/models"):
        from modules.civitai.download_civitai import download_civit_model_thread
        fn = download_civit_model_thread(model_name=None, model_url=s, model_path='', model_type='Model', token=None)
        if fn is not None:
            checkpoint_info = CheckpointInfo(fn)
            shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=civitai')
            return checkpoint_info

    return None

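
# Lookup order recap (derived from the code above; example inputs are hypothetical):
#   1. huggingface URL or 'huggingface/' prefix, e.g. 'https://huggingface.co/org/model'
#   2. exact alias match (name, filename, shorthash, or title)
#   3. title basename match, then hash-stripped title match
#   4. absolute .safetensors path on disk
#   5. reference model by 'path+subfolder' or by 'path'
#   6. huggingface hub search, then civitai download URL (when autodownload is enabled)

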
def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions"""
    try:
        with open(filename, "rb") as file:
            import hashlib
            m = hashlib.sha256()
            file.seek(0x100000)
            m.update(file.read(0x10000))
            shorthash = m.hexdigest()[0:8]
            return shorthash
    except FileNotFoundError:
        return 'NOFILE'
    except Exception:
        return 'NOHASH'

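
# Note: this legacy hash covers only 64 KiB (0x10000 bytes) starting at the
# 1 MiB (0x100000) offset, which is why two different files can share a hash.

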
def select_checkpoint(op='model', sd_model_checkpoint=None):
    model_checkpoint = sd_model_checkpoint or (shared.opts.data.get('sd_model_refiner', None) if op == 'refiner' else shared.opts.data.get('sd_model_checkpoint', None))
    if model_checkpoint is None or model_checkpoint == 'None' or len(model_checkpoint) < 3:
        return None
    checkpoint_info = get_closest_checkpoint_match(model_checkpoint)
    if checkpoint_info is not None:
        shared.log.info(f'Load {op}: select="{checkpoint_info.title}"')
        return checkpoint_info
    if len(checkpoints_list) == 0:
        shared.log.error("No models found")
        global warn_once  # pylint: disable=global-statement
        if not warn_once:
            warn_once = True
            shared.log.info("Set system paths to use existing folders")
            shared.log.info("or use --models-dir <path-to-folder> to specify base folder with all models")
            shared.log.info("or use --ckpt <path-to-checkpoint> to force using specific model")
        return None
    # fall back to the first available checkpoint
    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        if model_checkpoint != 'model.safetensors' and model_checkpoint != 'stabilityai/stable-diffusion-xl-base-1.0':
            shared.log.error(f'Load {op}: search="{model_checkpoint}" not found')
        else:
            shared.log.info("Selecting first available checkpoint")
    else:
        shared.log.info(f'Load {op}: select="{checkpoint_info.title}"')
    return checkpoint_info

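
# Usage sketch (illustrative): `info = select_checkpoint(op='model')` resolves
# the configured sd_model_checkpoint option; it returns None when nothing can
# be resolved and no models are available to fall back to.

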
def init_metadata():
    global sd_metadata  # pylint: disable=global-statement
    if sd_metadata is None:
        sd_metadata = shared.readfile(sd_metadata_file, lock=True, as_type="dict") if os.path.isfile(sd_metadata_file) else {}


def extract_thumbnail(filename, data):
    try:
        thumbnail = data.split(",")[1]
        thumbnail = base64.b64decode(thumbnail)
        thumbnail = io.BytesIO(thumbnail)
        thumbnail = Image.open(thumbnail)
        thumbnail = thumbnail.convert("RGB")
        thumbnail = thumbnail.resize((512, 512), Image.Resampling.HAMMING)
        fn = os.path.splitext(filename)[0]
        thumbnail.save(f"{fn}.thumb.jpg", quality=50)
    except Exception as e:
        shared.log.error(f"Error extracting thumbnail: {filename} {e}")

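
# Input sketch (illustrative): `data` is a data-URI such as
# 'data:image/jpeg;base64,/9j/4AAQ...'; split(",")[1] isolates the base64
# payload, which is decoded and saved as <model>.thumb.jpg next to the checkpoint.

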
def read_metadata_from_safetensors(filename):
    global sd_metadata  # pylint: disable=global-statement
    if sd_metadata is None:
        sd_metadata = shared.readfile(sd_metadata_file, lock=True, as_type="dict") if os.path.isfile(sd_metadata_file) else {}
    res = sd_metadata.get(filename, None)
    if res is not None:
        return res
    if not filename.endswith(".safetensors"):
        return {}
    if shared.cmd_opts.no_metadata:
        return {}
    res = {}
    t0 = time.time()
    with open(filename, mode="rb") as file:
        try:
            metadata_len = file.read(8)
            metadata_len = int.from_bytes(metadata_len, "little")
            json_start = file.read(2)
            if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
                shared.log.error(f'Model metadata invalid: file="{filename}" len={metadata_len} start={json_start}')
                return res
            json_data = json_start + file.read(metadata_len-2)
            json_obj = json.loads(json_data)
            for k, v in json_obj.get("__metadata__", {}).items():
                if k == 'modelspec.thumbnail' and v.startswith("data:"):
                    extract_thumbnail(filename, v)
                if v.startswith("data:"):
                    v = 'data'
                if k == 'format' and v == 'pt':
                    continue
                large = len(v) > 2048
                if large and k in ['ss_datasets', 'workflow', 'prompt', 'ss_bucket_info', 'sd_metadata_file']:
                    continue
                if v[0:1] == '{':
                    try:
                        v = json.loads(v)
                        if large and k == 'ss_tag_frequency':
                            v = {i: len(j) for i, j in v.items()}
                        if large and k == 'sd_merge_models':
                            scrub_dict(v, ['sd_merge_recipe'])
                    except Exception:
                        pass
                res[k] = v
        except Exception as e:
            shared.log.error(f'Model metadata: file="{filename}" {e}')
            from modules import errors
            errors.display(e, 'Model metadata')
    sd_metadata[filename] = res
    global sd_metadata_pending  # pylint: disable=global-statement
    sd_metadata_pending += 1
    t1 = time.time()
    global sd_metadata_timer  # pylint: disable=global-statement
    sd_metadata_timer += (t1 - t0)
    return res

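
# Format note: a .safetensors file starts with an 8-byte little-endian integer
# giving the JSON header length, followed by the header itself; free-form model
# metadata, when present, lives under the "__metadata__" key of that header,
# which is what the parser above reads.

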
def scrub_dict(dict_obj, keys):
    if not isinstance(dict_obj, dict):  # guard first, so recursion into non-dict list items is safe
        return
    for key in list(dict_obj.keys()):
        if key in keys:
            dict_obj.pop(key, None)
        elif isinstance(dict_obj[key], dict):
            scrub_dict(dict_obj[key], keys)
        elif isinstance(dict_obj[key], list):
            for item in dict_obj[key]:
                scrub_dict(item, keys)

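
# Example (illustrative):
#   scrub_dict({'a': 1, 'n': {'sd_merge_recipe': {}, 'b': 2}}, ['sd_merge_recipe'])
# removes the key at any nesting depth, leaving {'a': 1, 'n': {'b': 2}}.

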
def write_metadata():
    global sd_metadata_pending  # pylint: disable=global-statement
    if sd_metadata_pending == 0:
        shared.log.debug(f'Model metadata: file="{sd_metadata_file}" no changes')
        return
    shared.writefile(sd_metadata, sd_metadata_file)
    shared.log.info(f'Model metadata saved: file="{sd_metadata_file}" items={sd_metadata_pending} time={sd_metadata_timer:.2f}')
    sd_metadata_pending = 0