mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
168 lines
7.0 KiB
Python
168 lines
7.0 KiB
Python
import os
|
|
import html
|
|
import json
|
|
import concurrent
|
|
from datetime import datetime
|
|
from modules import shared, ui_extra_networks, sd_models, modelstats, paths
|
|
from modules.json_helpers import readfile
|
|
|
|
|
|
version_map = {
|
|
"QwenEdit": "Qwen",
|
|
"QwenEditPlus": "Qwen",
|
|
"Flux.1 D": "Flux",
|
|
"Flux.1 S": "Flux",
|
|
"FluxKontext": "Flux",
|
|
"SDXL 1.0": "SD XL",
|
|
"SDXL Hyper": "SD XL",
|
|
"StableDiffusion3": "SD 3",
|
|
"StableDiffusionXL": "SD XL",
|
|
"WanToVideo": "Wan",
|
|
"WanVACE": "Wan",
|
|
"Z": "Z-Image",
|
|
"Glm": "GLM-Image",
|
|
}
|
|
|
|
class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
|
|
def __init__(self):
|
|
super().__init__('Model')
|
|
|
|
def refresh(self):
|
|
shared.refresh_checkpoints()
|
|
|
|
def list_reference(self): # pylint: disable=inconsistent-return-statements
|
|
existing = [model.filename if model.type == 'safetensors' else model.name for model in sd_models.checkpoints_list.values()]
|
|
|
|
def reference_downloaded(url):
|
|
url = url.split('@')[0] if '@' in url else 'Diffusers/' + url
|
|
url = url.split('+')[0] if '+' in url else url
|
|
return any(model.endswith(url) for model in existing)
|
|
|
|
if not shared.opts.sd_checkpoint_autodownload or not shared.opts.extra_network_reference_enable:
|
|
shared.log.debug(f'Networks: type="reference" autodownload={shared.opts.sd_checkpoint_autodownload} enable={shared.opts.extra_network_reference_enable}')
|
|
return []
|
|
count = { 'total': 0, 'ready': 0, 'hidden': 0, 'experimental': 0, 'base': 0 }
|
|
|
|
reference_base = readfile(os.path.join('html', 'reference.json'), as_type="dict")
|
|
reference_quant = readfile(os.path.join('html', 'reference-quant.json'), as_type="dict")
|
|
reference_distilled = readfile(os.path.join('html', 'reference-distilled.json'), as_type="dict")
|
|
reference_community = readfile(os.path.join('html', 'reference-community.json'), as_type="dict")
|
|
reference_cloud = readfile(os.path.join('html', 'reference-cloud.json'), as_type="dict")
|
|
shared.reference_models = {}
|
|
shared.reference_models.update(reference_base)
|
|
shared.reference_models.update(reference_quant)
|
|
shared.reference_models.update(reference_community)
|
|
shared.reference_models.update(reference_distilled)
|
|
shared.reference_models.update(reference_cloud)
|
|
|
|
for k, v in shared.reference_models.items():
|
|
count['total'] += 1
|
|
url = v['path']
|
|
experimental = v.get('experimental', False)
|
|
if experimental:
|
|
if shared.cmd_opts.experimental:
|
|
shared.log.debug(f'Networks: experimental model="{k}"')
|
|
count['experimental'] += 1
|
|
else:
|
|
continue
|
|
preview = v.get('preview', v['path'])
|
|
preview_file = self.find_preview_file(os.path.join(paths.reference_path, preview))
|
|
name = os.path.normpath(os.path.join(paths.reference_path, k)).replace('\\', '/')
|
|
size = int(float(v.get('size', 0)) * 1024 * 1024 * 1024)
|
|
mtime = v.get('date', None)
|
|
if mtime is None:
|
|
_size, mtime = modelstats.stat(preview_file)
|
|
else:
|
|
try:
|
|
mtime = datetime.strptime(mtime, '%Y %B') # 2025 January
|
|
except Exception:
|
|
_size, mtime = modelstats.stat(preview_file)
|
|
if len(v.get("subfolder", "")) > 0:
|
|
path = f'{v.get("path", "")}+{v.get("subfolder", "")}'
|
|
else:
|
|
path = f'{v.get("path", "")}'
|
|
|
|
tag = v.get('tags', '')
|
|
if tag in count:
|
|
count[tag] += 1
|
|
elif tag != '':
|
|
count[tag] = 1
|
|
else:
|
|
count['base'] += 1
|
|
|
|
ready = reference_downloaded(url)
|
|
version = "ready" if ready else "download"
|
|
if tag == 'cloud':
|
|
version = 'Cloud'
|
|
if not ready and shared.opts.offline_mode:
|
|
count['hidden'] += 1
|
|
continue
|
|
if ready:
|
|
count['ready'] += 1
|
|
|
|
yield {
|
|
"type": 'Model',
|
|
"name": name,
|
|
"title": name,
|
|
"filename": url,
|
|
"preview": self.find_preview(os.path.join(paths.reference_path, preview)),
|
|
"local_preview": preview_file,
|
|
"onclick": '"' + html.escape(f"selectReference({json.dumps(path)})") + '"',
|
|
"hash": None,
|
|
"mtime": mtime,
|
|
"size": size,
|
|
"info": {},
|
|
"metadata": {},
|
|
"description": v.get('desc', ''),
|
|
"version": version,
|
|
"tags": tag,
|
|
}
|
|
shared.log.debug(f'Networks: type="reference" {count}')
|
|
|
|
def create_item(self, name):
|
|
record = None
|
|
try:
|
|
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoints_list.get(name)
|
|
size, mtime = modelstats.stat(checkpoint.filename)
|
|
record = {
|
|
"type": 'Model',
|
|
"name": checkpoint.name,
|
|
"title": checkpoint.title,
|
|
"filename": checkpoint.filename,
|
|
"hash": checkpoint.shorthash,
|
|
"metadata": checkpoint.metadata,
|
|
"onclick": '"' + html.escape(f"selectCheckpoint({json.dumps(name)})") + '"',
|
|
"mtime": mtime,
|
|
"size": size,
|
|
}
|
|
record['info'] = self.find_info(checkpoint.filename)
|
|
record['description'] = self.find_description(checkpoint.filename, record['info'])
|
|
version = self.find_version(checkpoint, record['info'])
|
|
if 'baseModel' in version:
|
|
record['version'] = version.get("baseModel", "")
|
|
elif '_class_name' in record['info']:
|
|
record['version'] = record['info'].get('_class_name', '').replace('Pipeline', '').replace('Image', '')
|
|
else:
|
|
record['version'] = ''
|
|
record['version'] = version_map.get(record['version'], record['version'])
|
|
|
|
except Exception as e:
|
|
shared.log.debug(f'Networks error: type=model file="{name}" {e}')
|
|
return record
|
|
|
|
def list_items(self):
|
|
items = []
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=shared.max_workers) as executor:
|
|
future_items = {executor.submit(self.create_item, cp): cp for cp in list(sd_models.checkpoints_list.copy())}
|
|
for future in concurrent.futures.as_completed(future_items):
|
|
item = future.result()
|
|
if item is not None:
|
|
items.append(item)
|
|
for record in self.list_reference():
|
|
items.append(record)
|
|
self.update_all_previews(items)
|
|
return items
|
|
|
|
def allowed_directories_for_previews(self):
|
|
return [v for v in [shared.opts.ckpt_dir, paths.reference_path, sd_models.model_path] if v is not None]
|