1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/ui_extra_networks_checkpoints.py
Vladimir Mandic 30da7803b5 futureproof
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2026-01-15 09:29:26 +00:00

168 lines
7.0 KiB
Python

import os
import html
import json
import concurrent
from datetime import datetime
from modules import shared, ui_extra_networks, sd_models, modelstats, paths
from modules.json_helpers import readfile
# Map raw model/pipeline identifiers to the display family name used by the UI.
# Built by inverting a family -> aliases table so related entries stay grouped.
_version_families = {
    "Qwen": ("QwenEdit", "QwenEditPlus"),
    "Flux": ("Flux.1 D", "Flux.1 S", "FluxKontext"),
    "SD XL": ("SDXL 1.0", "SDXL Hyper", "StableDiffusionXL"),
    "SD 3": ("StableDiffusion3",),
    "Wan": ("WanToVideo", "WanVACE"),
    "Z-Image": ("Z",),
    "GLM-Image": ("Glm",),
}
version_map = {alias: family for family, aliases in _version_families.items() for alias in aliases}
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
    """Extra-networks UI page listing local model checkpoints plus downloadable reference models."""

    def __init__(self):
        super().__init__('Model')

    def refresh(self):
        """Re-scan checkpoint folders so added/removed files are picked up."""
        shared.refresh_checkpoints()

    def list_reference(self): # pylint: disable=inconsistent-return-statements
        """Yield one item dict per reference model defined in the bundled html/reference*.json files.

        Returns an empty list (after logging) when checkpoint autodownload or
        reference networks are disabled in options.
        """
        # filenames (safetensors) or names (diffusers-style) of models already present locally
        existing = [model.filename if model.type == 'safetensors' else model.name for model in sd_models.checkpoints_list.values()]

        def reference_downloaded(url):
            # '@' marks a revision suffix; bare repo ids are stored under the 'Diffusers/' folder
            url = url.split('@')[0] if '@' in url else 'Diffusers/' + url
            url = url.split('+')[0] if '+' in url else url # strip '+subfolder' suffix
            return any(model.endswith(url) for model in existing)

        if not shared.opts.sd_checkpoint_autodownload or not shared.opts.extra_network_reference_enable:
            shared.log.debug(f'Networks: type="reference" autodownload={shared.opts.sd_checkpoint_autodownload} enable={shared.opts.extra_network_reference_enable}')
            return []
        count = { 'total': 0, 'ready': 0, 'hidden': 0, 'experimental': 0, 'base': 0 }
        reference_base = readfile(os.path.join('html', 'reference.json'), as_type="dict")
        reference_quant = readfile(os.path.join('html', 'reference-quant.json'), as_type="dict")
        reference_distilled = readfile(os.path.join('html', 'reference-distilled.json'), as_type="dict")
        reference_community = readfile(os.path.join('html', 'reference-community.json'), as_type="dict")
        reference_cloud = readfile(os.path.join('html', 'reference-cloud.json'), as_type="dict")
        # later updates win on duplicate keys: base < quant < community < distilled < cloud
        shared.reference_models = {}
        shared.reference_models.update(reference_base)
        shared.reference_models.update(reference_quant)
        shared.reference_models.update(reference_community)
        shared.reference_models.update(reference_distilled)
        shared.reference_models.update(reference_cloud)
        for k, v in shared.reference_models.items():
            count['total'] += 1
            url = v['path']
            experimental = v.get('experimental', False)
            if experimental:
                if shared.cmd_opts.experimental:
                    shared.log.debug(f'Networks: experimental model="{k}"')
                    count['experimental'] += 1
                else:
                    continue # experimental entries are hidden unless explicitly enabled
            preview = v.get('preview', v['path'])
            preview_file = self.find_preview_file(os.path.join(paths.reference_path, preview))
            name = os.path.normpath(os.path.join(paths.reference_path, k)).replace('\\', '/')
            size = int(float(v.get('size', 0)) * 1024 * 1024 * 1024) # 'size' field is in GB
            mtime = v.get('date', None)
            if mtime is None:
                _size, mtime = modelstats.stat(preview_file)
            else:
                try:
                    mtime = datetime.strptime(mtime, '%Y %B') # 2025 January
                except Exception:
                    # malformed date in json: fall back to the preview file's mtime
                    _size, mtime = modelstats.stat(preview_file)
            if len(v.get("subfolder", "")) > 0:
                path = f'{v.get("path", "")}+{v.get("subfolder", "")}'
            else:
                path = f'{v.get("path", "")}'
            tag = v.get('tags', '')
            if tag in count:
                count[tag] += 1
            elif tag != '':
                count[tag] = 1
            else:
                count['base'] += 1
            ready = reference_downloaded(url)
            version = "ready" if ready else "download"
            if tag == 'cloud':
                version = 'Cloud'
            if not ready and shared.opts.offline_mode:
                count['hidden'] += 1 # offline mode: hide anything that would require a download
                continue
            if ready:
                count['ready'] += 1
            yield {
                "type": 'Model',
                "name": name,
                "title": name,
                "filename": url,
                "preview": self.find_preview(os.path.join(paths.reference_path, preview)),
                "local_preview": preview_file,
                "onclick": '"' + html.escape(f"selectReference({json.dumps(path)})") + '"',
                "hash": None,
                "mtime": mtime,
                "size": size,
                "info": {},
                "metadata": {},
                "description": v.get('desc', ''),
                "version": version,
                "tags": tag,
            }
        # NOTE: only reached when the generator is fully consumed
        shared.log.debug(f'Networks: type="reference" {count}')

    def create_item(self, name):
        """Build the UI item dict for a single checkpoint name; returns None on any failure."""
        record = None
        try:
            checkpoint: sd_models.CheckpointInfo = sd_models.checkpoints_list.get(name)
            if checkpoint is None: # explicit guard: a stale name would otherwise raise an opaque AttributeError below
                shared.log.debug(f'Networks error: type=model file="{name}" checkpoint not found')
                return None
            size, mtime = modelstats.stat(checkpoint.filename)
            record = {
                "type": 'Model',
                "name": checkpoint.name,
                "title": checkpoint.title,
                "filename": checkpoint.filename,
                "hash": checkpoint.shorthash,
                "metadata": checkpoint.metadata,
                "onclick": '"' + html.escape(f"selectCheckpoint({json.dumps(name)})") + '"',
                "mtime": mtime,
                "size": size,
            }
            record['info'] = self.find_info(checkpoint.filename)
            record['description'] = self.find_description(checkpoint.filename, record['info'])
            version = self.find_version(checkpoint, record['info'])
            if 'baseModel' in version:
                record['version'] = version.get("baseModel", "")
            elif '_class_name' in record['info']:
                # derive the family from the diffusers pipeline class name
                record['version'] = record['info'].get('_class_name', '').replace('Pipeline', '').replace('Image', '')
            else:
                record['version'] = ''
            # normalize raw identifiers to the display family name
            record['version'] = version_map.get(record['version'], record['version'])
        except Exception as e:
            shared.log.debug(f'Networks error: type=model file="{name}" {e}')
        return record

    def list_items(self):
        """Return all checkpoint items (built concurrently) followed by reference-model items."""
        items = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
            # snapshot the key list so worker threads iterate a stable collection;
            # a plain list of futures suffices since the mapped names were never read back
            futures = [executor.submit(self.create_item, cp) for cp in list(sd_models.checkpoints_list.copy())]
            for future in concurrent.futures.as_completed(futures):
                item = future.result()
                if item is not None:
                    items.append(item)
        for record in self.list_reference():
            items.append(record)
        self.update_all_previews(items)
        return items

    def allowed_directories_for_previews(self):
        """Directories from which preview images may be served."""
        return [v for v in [shared.opts.ckpt_dir, paths.reference_path, sd_models.model_path] if v is not None]