1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/modelstats.py
Vladimir Mandic 8b698ed67f update qwen pruning and allow hf models in subfolders
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-10-04 15:49:20 -04:00

112 lines
3.5 KiB
Python

import os
from datetime import datetime
import torch
from modules import shared, sd_models
def walk(folder: str):
    """Return the paths of all files found under ``folder``, recursively.

    Each entry is the directory reported by ``os.walk`` joined with the file
    name, so paths are prefixed with ``folder`` as given. Directories
    themselves are not included in the result.
    """
    return [
        os.path.join(dirpath, entry)
        for dirpath, _, entries in os.walk(folder)
        for entry in entries
    ]
def stat(fn: str):
    """Return ``(size, mtime)`` for a file or folder path.

    Args:
        fn: path to inspect; may be ``None`` or empty.

    Returns:
        A tuple of the size in bytes and the modification time as a
        ``datetime`` truncated to whole seconds. A missing, empty or
        non-existent path yields ``(0, datetime.fromtimestamp(0))``.
        Symlinks and anything that is neither a regular file nor a folder
        report size 0; a folder reports the summed size of every file
        found below it.
    """
    if not fn or not os.path.exists(fn):
        return 0, datetime.fromtimestamp(0)
    # inspect the path itself rather than whatever a symlink points at
    info = os.stat(fn, follow_symlinks=False)
    modified = datetime.fromtimestamp(info.st_mtime).replace(microsecond=0)
    if os.path.islink(fn):
        return 0, modified
    if os.path.isfile(fn):
        return round(info.st_size), modified
    if os.path.isdir(fn):
        total = 0
        for child in walk(fn):
            total += stat(child)[0]
        return round(total), modified
    return 0, modified
class Module():
    """Summary of one named component of a model.

    Records the component's class name and, when the component is a
    ``torch.nn.Module``, its device, dtype, parameter count, sub-module
    count and quantization method.
    """
    # name the component is registered under
    name: str = ''
    # class name of the component (or the second tuple item for tuple components)
    cls: str = None
    # device/dtype as reported by the component itself, if it exposes them
    device: str = None
    dtype: str = None
    # total number of parameter elements and number of nn sub-modules
    params: int = 0
    modules: int = 0
    # value of the component's `quantization_method` attribute, if any
    quant: str = None
    # the component's `config` attribute, if any
    config: dict = None

    def __init__(self, name, module):
        """Inspect ``module`` and store its summary under ``name``."""
        self.name = name
        if isinstance(module, tuple):
            # tuple components carry the class name as their second item
            self.cls = module[1]
        else:
            self.cls = module.__class__.__name__
        if hasattr(module, 'config'):
            self.config = module.config
        if not isinstance(module, torch.nn.Module):
            # remaining fields only apply to torch modules; keep class defaults
            return
        self.device = getattr(module, 'device', None)
        self.dtype = getattr(module, 'dtype', None)
        self.params = sum(param.numel() for param in module.parameters(recurse=True))
        # modules() yields the module itself plus every nested sub-module
        self.modules = sum(1 for _ in module.modules())
        self.quant = getattr(module, 'quantization_method', None)

    def __repr__(self):
        """Return a one-line summary; optional fields appear only when set."""
        parts = [f'name="{self.name}" cls={self.cls} config={self.config is not None}']
        if self.device or self.dtype:
            parts.append(f' device={self.device} dtype={self.dtype}')
        if self.params or self.modules:
            parts.append(f' params={self.params} modules={self.modules}')
        return ''.join(parts)
class Model():
    """Summary of the currently loaded model and its checkpoint file.

    On construction the class name and type are read from ``shared`` and the
    checkpoint is looked up via ``sd_models.get_closest_checkpoint_match``;
    when a match is found its name, short hash, metadata, file size and
    modification time are recorded. ``modules`` starts empty and is meant to
    be filled by the caller.
    """
    name: str = ''
    # NOTE(review): `fn` is declared but never assigned anywhere in this class
    fn: str = ''
    type: str = ''
    cls: str = ''
    hash: str = ''
    # `meta` and `modules` are mutable, so they are created per instance in
    # __init__ instead of being shared through a class-level default
    meta: dict
    size: int = 0
    mtime: datetime = None
    info: sd_models.CheckpointInfo = None
    modules: list[Module]

    def __init__(self, name):
        """Build the summary for checkpoint ``name``.

        When no model is loaded (``shared.sd_loaded`` is falsy) only ``name``
        is stored and every other field keeps its default.
        """
        self.name = name
        # fresh containers for every instance: a class-level `{}` / `[]` would
        # be one object aliased by all instances, so filling one model's
        # module list would silently change every other model's list too
        self.meta = {}
        self.modules = []
        if not shared.sd_loaded:
            return
        self.cls = shared.sd_model.__class__.__name__
        self.type = shared.sd_model_type
        self.info = sd_models.get_closest_checkpoint_match(name)
        if self.info is not None:
            self.name = self.info.name or self.name
            self.hash = self.info.shorthash or ''
            self.meta = self.info.metadata or {}
            self.size, self.mtime = stat(self.info.filename)

    def __repr__(self):
        """Return a one-line summary including the list of modules."""
        return f'model="{self.name}" type={self.type} class={self.cls} size={self.size} mtime="{self.mtime}" modules={self.modules}'
def analyze():
    """Collect a summary of the loaded model and each of its components.

    Returns:
        ``None`` when no model is loaded; otherwise a ``Model`` whose
        ``modules`` list holds one ``Module`` entry per component name that
        does not start with an underscore. If the model's class name is
        empty the ``Model`` is returned without inspecting components.
    """
    if not shared.sd_loaded:
        return None
    model = Model(shared.opts.sd_model_checkpoint)
    if model.cls == '':
        return model
    # component names come from the model's `_internal_dict` when it has one,
    # otherwise from the signature reported by sd_models
    if hasattr(shared.sd_model, '_internal_dict'):
        keys = shared.sd_model._internal_dict.keys()  # pylint: disable=protected-access
    else:
        keys = sd_models.get_signature(shared.sd_model).keys()
    # bind a brand-new list to this instance instead of clearing and appending
    # in place: mutating in place would also rewrite the list seen by any
    # other Model object that shares the same underlying list
    model.modules = [
        Module(k, getattr(shared.sd_model, k, None))
        for k in keys
        if not k.startswith('_')
    ]
    shared.log.debug(f'Analyzed: {model}')
    return model