1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-29 05:02:09 +03:00
Files
sdnext/modules/modelstats.py
Vladimir Mandic cf9fe20803 omnigen!
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2024-10-22 12:10:53 -04:00

86 lines
2.8 KiB
Python

import os
from datetime import datetime
import torch
from modules import shared, sd_models
class Module():
    """Descriptive record of a single model component, used for diagnostic logging.

    Captures the component's class name, whether it exposes a `config`, and, for
    `torch.nn.Module` instances, device/dtype attributes (if present), total
    parameter count and number of submodules (including itself).
    """
    name: str = ''
    cls: str = None
    device: str = None
    dtype: str = None
    params: int = 0
    modules: int = 0
    config: dict = None

    def __init__(self, name, module):
        """Inspect `module` and record its summary under the component name `name`.

        A tuple `module` is treated as a class-name spec and reports its second
        element as `cls`; any other object reports its own class name.
        """
        self.name = name
        self.cls = module[1] if isinstance(module, tuple) else module.__class__.__name__
        if hasattr(module, 'config'):
            self.config = module.config
        if not isinstance(module, torch.nn.Module):
            return  # non-torch components keep the class-level defaults for the torch-only fields
        self.device = getattr(module, 'device', None)  # plain nn.Module has no .device; stays None then
        self.dtype = getattr(module, 'dtype', None)
        self.params = sum(tensor.numel() for tensor in module.parameters(recurse=True))
        self.modules = sum(1 for _ in module.modules())  # modules() yields the module itself too

    def __repr__(self):
        fragments = [f'name="{self.name}" cls={self.cls} config={self.config is not None}']
        if self.device or self.dtype:
            fragments.append(f' device={self.device} dtype={self.dtype}')
        if self.params or self.modules:
            fragments.append(f' params={self.params} modules={self.modules}')
        return ''.join(fragments)
class Model():
    """Snapshot of the currently loaded model, used for diagnostic logging.

    When no model is loaded (`shared.sd_loaded` is falsy), only `name` is set and
    `cls` stays '', which callers can use to detect the "nothing loaded" case.
    """
    name: str = ''
    fn: str = ''
    type: str = ''
    cls: str = ''
    hash: str = ''
    meta: dict  # assigned per instance in __init__
    size: int = 0
    mtime: datetime = None
    info: sd_models.CheckpointInfo = None
    modules: list[Module]  # assigned per instance in __init__

    def __init__(self, name):
        """Populate the snapshot from `shared.sd_model` and the checkpoint matched by `name`.

        Args:
            name: checkpoint name to look up; replaced by the matched checkpoint's
                name when a match is found and that name is non-empty.
        """
        self.name = name
        # Per-instance containers: a class-level `{}`/`[]` default is a single object
        # shared by every Model, so filling one instance's modules would alter all others.
        self.meta = {}
        self.modules = []
        if not shared.sd_loaded:
            return
        self.cls = shared.sd_model.__class__.__name__
        self.type = shared.sd_model_type
        self.info = sd_models.get_closet_checkpoint_match(name)
        if self.info is None:
            return
        self.name = self.info.name or self.name
        self.hash = self.info.shorthash or ''
        self.meta = self.info.metadata or {}
        filename = self.info.filename
        # os.path.exists raises TypeError on None, so guard against a missing filename first
        if filename and os.path.exists(filename):
            stat = os.stat(filename)
            self.mtime = datetime.fromtimestamp(stat.st_mtime).replace(microsecond=0)
            # size is only reported for a regular file; any other existing path keeps size=0
            if os.path.isfile(filename):
                self.size = stat.st_size

    def __repr__(self):
        return f'model="{self.name}" type={self.type} class={self.cls} size={self.size} mtime="{self.mtime}" modules={self.modules}'
def analyze():
    """Build a `Model` snapshot of the currently loaded model and log it at debug level.

    Returns:
        Model: with `modules` holding one `Module` per component key of
        `shared.sd_model`. If no model is loaded (`model.cls == ''`), the bare
        snapshot is returned without component analysis.
    """
    model = Model(shared.opts.sd_model_checkpoint)
    if model.cls == '':
        return model
    if hasattr(shared.sd_model, '_internal_dict'):
        # presumably the component registry of a diffusers ConfigMixin pipeline — confirm
        keys = shared.sd_model._internal_dict.keys() # pylint: disable=protected-access
    else:
        keys = sd_models.get_signature(shared.sd_model).keys()
    # Rebind instead of .clear()/.append(): Model.modules may be a class-level list
    # shared by every Model instance, and mutating it would rewrite earlier snapshots.
    model.modules = [Module(k, getattr(shared.sd_model, k, None)) for k in keys]
    shared.log.debug(f'Analyzed: {model}')
    return model