import io
import base64
import os
import re
import time
import json
import collections
from PIL import Image
from modules import shared, paths, modelloader, hashes, sd_hijack_accelerate


checkpoints_list = {}
checkpoint_aliases = {}
checkpoints_loaded = collections.OrderedDict()
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
sd_metadata_file = os.path.join(paths.data_path, "metadata.json")
sd_metadata = None
sd_metadata_pending = 0
sd_metadata_timer = 0
warn_once = False


class CheckpointInfo:
    def __init__(self, filename, sha=None, subfolder=None):
        self.name = None
        self.hash = sha
        self.filename = filename
        self.type = ''
        self.subfolder = subfolder
        relname = filename
        app_path = os.path.abspath(paths.script_path)

        def rel(fn, path):
            try:
                return os.path.relpath(fn, path)
            except Exception:
                return fn

        if relname.startswith('..'):
            relname = os.path.abspath(relname)
        if relname.startswith(shared.opts.ckpt_dir):
            relname = rel(filename, shared.opts.ckpt_dir)
        elif relname.startswith(shared.opts.diffusers_dir):
            relname = rel(filename, shared.opts.diffusers_dir)
        elif relname.startswith(model_path):
            relname = rel(filename, model_path)
        elif relname.startswith(paths.script_path):
            relname = rel(filename, paths.script_path)
        elif relname.startswith(app_path):
            relname = rel(filename, app_path)
        else:
            relname = os.path.abspath(relname)
        relname, ext = os.path.splitext(relname)
        ext = ext.lower()[1:]

        if filename.lower() == 'none':
            self.name = 'none'
            self.relname = 'none'
            self.sha256 = None
            self.type = 'unknown'
        elif os.path.isfile(filename):  # ckpt or safetensor
            self.name = relname
            self.filename = filename
            self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{relname}")
            self.type = ext
            if 'nf4' in filename:
                self.type = 'transformer'
        else:  # maybe a diffuser
            if self.hash is None:
                repo = [r for r in modelloader.diffuser_repos if self.filename == r['name']]
            else:
                repo = [r for r in modelloader.diffuser_repos if self.hash == r['hash']]
            if len(repo) == 0:
                self.name = filename
                self.filename = filename
                self.sha256 = None
                self.type = 'unknown'
            else:
                self.name = os.path.join(os.path.basename(shared.opts.diffusers_dir), repo[0]['name'])
                self.filename = repo[0]['path']
                self.sha256 = repo[0]['hash']
                self.type = 'diffusers'

        self.shorthash = self.sha256[0:10] if self.sha256 else None
        self.title = self.name if self.shorthash is None else f'{self.name} [{self.shorthash}]'
        self.path = self.filename
        self.model_name = os.path.basename(self.name)
        self.metadata = read_metadata_from_safetensors(filename)
        # shared.log.debug(f'Checkpoint: type={self.type} name={self.name} filename={self.filename} hash={self.shorthash} title={self.title}')

    def register(self):
        checkpoints_list[self.title] = self
        for i in [self.name, self.filename, self.shorthash, self.title]:
            if i is not None:
                checkpoint_aliases[i] = self

    def calculate_shorthash(self):
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return None
        self.shorthash = self.sha256[0:10]
        if self.title in checkpoints_list:
            checkpoints_list.pop(self.title)
        self.title = f'{self.name} [{self.shorthash}]'
        self.register()
        return self.shorthash

    def __str__(self):
        return f'CheckpointInfo(name="{self.name}" filename="{self.filename}" hash={self.shorthash} type={self.type} title="{self.title}" path="{self.path}" subfolder="{self.subfolder}")'


def setup_model():
    list_models()
    sd_hijack_accelerate.hijack_hfhub()
    # sd_hijack_accelerate.hijack_torch_conv()
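# Usage sketch (the path below is illustrative, not a shipped default):
# constructing a CheckpointInfo resolves the display name relative to the
# configured model folders and pulls a cached sha256 if one exists; register()
# then makes the checkpoint resolvable through checkpoint_aliases by name,
# filename, shorthash, or title:
#   info = CheckpointInfo('/models/Stable-diffusion/my-model.safetensors')
#   info.register()
#   assert checkpoint_aliases[info.title] is info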
def checkpoint_titles(use_short=False):
    def convert(name):
        return int(name) if name.isdigit() else name.lower()

    def alphanumeric_key(key):
        return [convert(c) for c in re.split("([0-9]+)", key)]

    if use_short:
        return sorted([x.title.rsplit("\\", 1)[-1].rsplit("/", 1)[-1] for x in checkpoints_list.values()], key=alphanumeric_key)
    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)


def list_models():
    t0 = time.time()
    global checkpoints_list  # pylint: disable=global-statement
    checkpoints_list.clear()
    checkpoint_aliases.clear()
    ext_filter = [".safetensors"]
    model_list = list(modelloader.load_models(model_path=model_path, model_url=None, command_path=shared.opts.ckpt_dir, ext_filter=ext_filter, download_name=None, ext_blacklist=[".vae.ckpt", ".vae.safetensors"]))
    safetensors_list = []
    for filename in sorted(model_list, key=str.lower):
        checkpoint_info = CheckpointInfo(filename)
        safetensors_list.append(checkpoint_info)
        if checkpoint_info.name is not None:
            checkpoint_info.register()
    diffusers_list = []
    for repo in modelloader.load_diffusers_models(clear=True):
        checkpoint_info = CheckpointInfo(repo['name'], sha=repo['hash'])
        diffusers_list.append(checkpoint_info)
        if checkpoint_info.name is not None:
            checkpoint_info.register()
    if shared.cmd_opts.ckpt is not None:
        checkpoint_info = CheckpointInfo(shared.cmd_opts.ckpt)
        if checkpoint_info.name is not None:
            checkpoint_info.register()
            shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
        elif shared.cmd_opts.ckpt != shared.default_sd_model_file:
            shared.log.warning(f'Load model: path="{shared.cmd_opts.ckpt}" not found')
    shared.log.info(f'Available Models: safetensors="{shared.opts.ckpt_dir}":{len(safetensors_list)} diffusers="{shared.opts.diffusers_dir}":{len(diffusers_list)} reference={len(list(shared.reference_models))} items={len(checkpoints_list)} time={time.time()-t0:.2f}')
    checkpoints_list = dict(sorted(checkpoints_list.items(), key=lambda cp: cp[1].filename))
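# Usage sketch: list_models() rebuilds the registry from disk plus registered
# diffusers, after which checkpoint_titles() returns naturally-sorted titles
# (numeric runs compare as numbers, so 'model-2' sorts before 'model-10'):
#   list_models()
#   titles = checkpoint_titles(use_short=True)  # basename-only titles for UI dropdowns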
def update_model_hashes():
    def update_model_hashes_table(rows):
        html = """
            <table>
                <thead><tr><th>Name</th><th>Type</th><th>Hash</th></tr></thead>
                <tbody>{tbody}</tbody>
            </table>
        """
        tbody = ''
        for row in rows:
            try:
                tbody += f"""
                    <tr>
                        <td>{row.name}</td>
                        <td>{row.type}</td>
                        <td>{row.shorthash}</td>
                    </tr>
                """
            except Exception as e:
                shared.log.error(f'Model list: row={row} {e}')
        return html.format(tbody=tbody)

    lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.hash is None]
    for ckpt in lst:
        ckpt.hash = model_hash(ckpt.filename)
    lst = [ckpt for ckpt in checkpoints_list.values() if ckpt.sha256 is None or ckpt.shorthash is None]
    shared.log.info(f'Models list: hash missing={len(lst)} total={len(checkpoints_list)}')
    updated = []
    for ckpt in lst:
        ckpt.sha256 = hashes.sha256(ckpt.filename, f"checkpoint/{ckpt.name}")
        ckpt.shorthash = ckpt.sha256[0:10] if ckpt.sha256 is not None else None
        updated.append(ckpt)
        yield update_model_hashes_table(updated)  # generator: yields a progressively growing table


def remove_hash(s):
    return re.sub(r'\s*\[.*?\]', '', s)
def get_closest_checkpoint_match(s: str) -> CheckpointInfo:
    # direct hf url
    if s.startswith('https://huggingface.co/'):
        model_name = s.replace('https://huggingface.co/', '')
        checkpoint_info = CheckpointInfo(model_name)  # create a virtual model info
        checkpoint_info.type = 'huggingface'
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=huggingface')
        return checkpoint_info
    if s.startswith('huggingface/'):
        model_name = s.replace('huggingface/', '')
        checkpoint_info = CheckpointInfo(model_name)  # create a virtual model info
        checkpoint_info.type = 'huggingface'
        return checkpoint_info
    # alias search
    checkpoint_info = checkpoint_aliases.get(s, None)
    if checkpoint_info is not None:
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=alias')
        return checkpoint_info
    # models search
    found = sorted([info for info in checkpoints_list.values() if os.path.basename(info.title).lower() == s.lower()], key=lambda x: len(x.title))
    if len(found) == 1:
        checkpoint_info = found[0]
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=hash')
        return checkpoint_info
    # nohash search
    found = sorted([info for info in checkpoints_list.values() if remove_hash(info.title).lower() == remove_hash(s).lower()], key=lambda x: len(x.title))
    if len(found) == 1:
        checkpoint_info = found[0]
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=model')
        return checkpoint_info
    # absolute path
    if s.endswith('.safetensors') and os.path.isfile(s):
        checkpoint_info = CheckpointInfo(s)
        checkpoint_info.type = 'safetensors'
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=safetensors')
        return checkpoint_info
    # reference search
    ref = [(k, v) for k, v in shared.reference_models.items() if f"{v.get('path', '')}+{v.get('subfolder', '')}" == s]
    if len(ref) == 0:
        ref = [(k, v) for k, v in shared.reference_models.items() if v.get('path', '') == s]
    if len(ref) > 0:
        _name, info = ref[0]
        checkpoint_info = CheckpointInfo(s)
        checkpoint_info.subfolder = info.get('subfolder', None)
        checkpoint_info.type = 'reference'
        shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=reference')
        return checkpoint_info
    # huggingface search
    if shared.opts.sd_checkpoint_autodownload and (s.count('/') == 1 or s.count('/') == 2):
        if s.count('/') == 2:
            subfolder = '/'.join(s.split('/')[2:])  # subfolder
            s = '/'.join(s.split('/')[:2])  # only user/model
        else:
            subfolder = None
        modelloader.hf_login()
        found = modelloader.find_diffuser(s, full=True)
        if found is None:
            return None
        found = [f for f in found if f == s]
        shared.log.info(f'HF search: model="{s}" results={found}')
        if len(found) == 1:
            checkpoint_info = CheckpointInfo(s)
            checkpoint_info.type = 'huggingface'
            if subfolder is not None and len(subfolder) > 0:
                checkpoint_info.subfolder = subfolder
            shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=huggingface')
            return checkpoint_info
    # civitai search
    if shared.opts.sd_checkpoint_autodownload and s.startswith("https://civitai.com/api/download/models"):
        from modules.civitai.download_civitai import download_civit_model_thread
        fn = download_civit_model_thread(model_name=None, model_url=s, model_path='', model_type='Model', token=None)
        if fn is not None:
            checkpoint_info = CheckpointInfo(fn)
            shared.log.debug(f'Search model: name="{s}" matched="{checkpoint_info.path}" type=civitai')
            return checkpoint_info
    return None


def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions"""
    try:
        with open(filename, "rb") as file:
            import hashlib
            m = hashlib.sha256()
            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'
    except Exception:
        return 'NOHASH'


def select_checkpoint(op='model', sd_model_checkpoint=None):
    model_checkpoint = sd_model_checkpoint or (shared.opts.data.get('sd_model_refiner', None) if op == 'refiner' else shared.opts.data.get('sd_model_checkpoint', None))
    if model_checkpoint is None or model_checkpoint == 'None' or len(model_checkpoint) < 3:
        return None
    checkpoint_info = get_closest_checkpoint_match(model_checkpoint)
    if checkpoint_info is not None:
        shared.log.info(f'Load {op}: select="{checkpoint_info.title}"')
        return checkpoint_info
    if len(checkpoints_list) == 0:
        shared.log.error("No models found")
        global warn_once  # pylint: disable=global-statement
        if not warn_once:
            warn_once = True
            shared.log.info("Set system paths to use existing folders")
            shared.log.info("  or use --models-dir to specify base folder with all models")
            shared.log.info("  or use --ckpt to force using specific model")
        return None
    if model_checkpoint != 'model.safetensors' and model_checkpoint != 'stabilityai/stable-diffusion-xl-base-1.0':
        shared.log.error(f'Load {op}: search="{model_checkpoint}" not found')
    else:
        shared.log.info("Selecting first available checkpoint")
    return checkpoint_info


def init_metadata():
    global sd_metadata  # pylint: disable=global-statement
    if sd_metadata is None:
        sd_metadata = shared.readfile(sd_metadata_file, lock=True, as_type="dict") if os.path.isfile(sd_metadata_file) else {}


def extract_thumbnail(filename, data):
    try:
        thumbnail = data.split(",")[1]
        thumbnail = base64.b64decode(thumbnail)
        thumbnail = io.BytesIO(thumbnail)
        thumbnail = Image.open(thumbnail)
        thumbnail = thumbnail.convert("RGB")
        thumbnail = thumbnail.resize((512, 512), Image.Resampling.HAMMING)
        fn = os.path.splitext(filename)[0]
        thumbnail.save(f"{fn}.thumb.jpg", quality=50)
    except Exception as e:
        shared.log.error(f"Error extracting thumbnail: {filename} {e}")
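# Usage sketch: select_checkpoint() is the entry point callers use; it resolves
# the configured checkpoint through get_closest_checkpoint_match(), which tries
# URL, alias, title, filesystem, reference, huggingface, and civitai lookups in
# that order:
#   info = select_checkpoint(op='model')
#   if info is not None:
#       load(info)  # 'load' is a placeholder for the caller's loading routine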
def read_metadata_from_safetensors(filename):
    global sd_metadata  # pylint: disable=global-statement
    if sd_metadata is None:
        sd_metadata = shared.readfile(sd_metadata_file, lock=True, as_type="dict") if os.path.isfile(sd_metadata_file) else {}
    res = sd_metadata.get(filename, None)
    if res is not None:
        return res
    if not filename.endswith(".safetensors"):
        return {}
    if shared.cmd_opts.no_metadata:
        return {}
    res = {}
    t0 = time.time()
    with open(filename, mode="rb") as file:
        try:
            metadata_len = file.read(8)
            metadata_len = int.from_bytes(metadata_len, "little")
            json_start = file.read(2)
            if metadata_len <= 2 or json_start not in (b'{"', b"{'"):
                shared.log.error(f'Model metadata invalid: file="{filename}" len={metadata_len} start={json_start}')
                return res
            json_data = json_start + file.read(metadata_len - 2)
            json_obj = json.loads(json_data)
            for k, v in json_obj.get("__metadata__", {}).items():
                if k == 'modelspec.thumbnail' and v.startswith("data:"):
                    extract_thumbnail(filename, v)
                if v.startswith("data:"):
                    v = 'data'
                if k == 'format' and v == 'pt':
                    continue
                large = len(v) > 2048
                if large and k in ['ss_datasets', 'workflow', 'prompt', 'ss_bucket_info', 'sd_metadata_file']:
                    continue
                if v[0:1] == '{':
                    try:
                        v = json.loads(v)
                        if large and k == 'ss_tag_frequency':
                            v = {i: len(j) for i, j in v.items()}
                        if large and k == 'sd_merge_models':
                            scrub_dict(v, ['sd_merge_recipe'])
                    except Exception:
                        pass
                res[k] = v
        except Exception as e:
            shared.log.error(f'Model metadata: file="{filename}" {e}')
            from modules import errors
            errors.display(e, 'Model metadata')
    sd_metadata[filename] = res
    global sd_metadata_pending  # pylint: disable=global-statement
    sd_metadata_pending += 1
    t1 = time.time()
    global sd_metadata_timer  # pylint: disable=global-statement
    sd_metadata_timer += (t1 - t0)
    return res


def scrub_dict(dict_obj, keys):
    if not isinstance(dict_obj, dict):  # recursion may pass list items that are not dicts
        return
    for key in list(dict_obj.keys()):
        if key in keys:
            dict_obj.pop(key, None)
        elif isinstance(dict_obj[key], dict):
            scrub_dict(dict_obj[key], keys)
        elif isinstance(dict_obj[key], list):
            for item in dict_obj[key]:
                scrub_dict(item, keys)


def write_metadata():
    global sd_metadata_pending  # pylint: disable=global-statement
    if sd_metadata_pending == 0:
        shared.log.debug(f'Model metadata: file="{sd_metadata_file}" no changes')
        return
    shared.writefile(sd_metadata, sd_metadata_file)
    shared.log.info(f'Model metadata saved: file="{sd_metadata_file}" items={sd_metadata_pending} time={sd_metadata_timer:.2f}')
    sd_metadata_pending = 0
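# Typical metadata lifecycle (the path below is illustrative): reads are cached
# in sd_metadata and only flushed to disk in one batch once something new was
# parsed:
#   init_metadata()
#   meta = read_metadata_from_safetensors('/models/Stable-diffusion/my-model.safetensors')
#   write_metadata()  # no-op unless sd_metadata_pending > 0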