mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
104 lines
3.2 KiB
Python
Executable File
104 lines
3.2 KiB
Python
Executable File
#!/usr/bin/env python
|
|
import os
|
|
import sys
|
|
from rich import print as pprint
|
|
|
|
|
|
def has(obj, attr, *args):
    """Follow a dotted path through nested dicts and return what is found.

    Args:
        obj: Root container; anything that is not a dict yields ``False``.
        attr: Dotted key path such as ``'a.b.c'``; each segment is looked up
            with ``dict.get`` on the value reached so far.
        *args: Extra positional arguments. The tuple itself is used as the
            ``dict.get`` default, so a missing final key returns ``args``
            (the empty tuple when no extras are passed).
            NOTE(review): returning the tuple rather than ``args[0]`` looks
            unintentional - confirm before relying on it.

    Returns:
        The value stored at the path (which may itself be falsy, e.g. ``None``),
        ``args`` when the last segment is missing, or ``False`` when the walk
        hits a non-dict value before the path is exhausted.
    """
    if not isinstance(obj, dict):
        return False
    node = obj
    for part in attr.split('.'):
        if not isinstance(node, dict):
            # Nothing to descend into; every remaining segment would also fail.
            return False
        node = node.get(part, args)
    return node
|
|
|
|
|
|
def remove_entries_after_depth(d, depth, current_depth=0):
    """Return a copy of a nested dict truncated to ``depth`` levels.

    Args:
        d: Value to prune. Dicts are rebuilt recursively; any other value is
            returned unchanged.
        depth: Number of levels to keep. Anything reached at
            ``current_depth >= depth`` is replaced by ``None``.
        current_depth: Level of ``d`` within the overall structure; callers
            normally leave this at its default of 0.

    Returns:
        ``None`` when ``d`` lies at or beyond the cut-off, a new pruned dict
        when ``d`` is a dict, otherwise ``d`` itself. Because ``None`` marks a
        pruned entry, keys whose pruned value is ``None`` (including keys that
        were mapped to ``None`` in the input) are omitted from the result.
    """
    try:
        if current_depth >= depth:
            return None
        if isinstance(d, dict):
            pruned = {}
            for key, value in d.items():
                # Recurse exactly once per value; evaluating the recursion both
                # as filter and as value doubles the work at every level.
                child = remove_entries_after_depth(value, depth, current_depth + 1)
                if child is not None:
                    pruned[key] = child
            return pruned
    except Exception:
        # Deliberately best-effort: on any failure (e.g. an incomparable
        # depth) fall through and hand back the input untouched.
        pass
    return d
|
|
|
|
|
|
def list_to_dict(flat_list):
    """Build a nested dict from an iterable of dotted key strings.

    Each item is split on ``'.'``; every segment but the last becomes a nested
    dict and the last segment becomes a leaf mapped to ``None``.

    Problems are handled per item so that one bad entry does not discard
    the rest of the keys:

    * items without a ``split`` method (non-strings) are skipped;
    * when a key is both a leaf and a prefix of another key (``'a'`` and
      ``'a.b'``), the subtree wins regardless of the order of the items.

    Args:
        flat_list: Iterable of dotted key strings. A non-iterable value
            (e.g. ``None``) yields an empty dict.

    Returns:
        A dict of nested dicts whose leaves are ``None``.
    """
    result_dict = {}
    try:
        items = iter(flat_list)
    except TypeError:
        # Not iterable at all: nothing to convert.
        return result_dict
    for item in items:
        try:
            keys = item.split('.')
        except AttributeError:
            continue
        node = result_dict
        for key in keys[:-1]:
            child = node.get(key)
            if not isinstance(child, dict):
                # Missing, or previously recorded as a leaf: promote to a subtree.
                child = {}
                node[key] = child
            node = child
        # setdefault keeps an existing subtree instead of overwriting it.
        node.setdefault(keys[-1], None)
    return result_dict
|
|
|
|
|
|
def list_compact(flat_list):
    """Reduce dotted keys to their first two segments, dropping duplicates.

    Args:
        flat_list: Iterable of dotted key strings.

    Returns:
        A list of the distinct ``'first.second'`` prefixes (or the whole key
        when it has fewer than two segments), in order of first appearance.
    """
    # dict.fromkeys keeps the first occurrence of each prefix, in input order.
    prefixes = ('.'.join(item.split('.')[:2]) for item in flat_list)
    return list(dict.fromkeys(prefixes))
|
|
|
|
|
|
def guess_dct(dct: dict):
    """Guess a model label from the nested key tree of a checkpoint.

    The guess is based purely on how many child entries sit under certain
    dotted paths of ``dct`` (looked up with ``has``). Checks run in the order
    listed and the first match wins.

    Args:
        dct: Nested dict of checkpoint keys.

    Returns:
        One of ``'sd15'``, ``'sdxl'``, ``'sd35-medium'``, ``'sd35-large'``,
        ``'chroma'``, ``'flux-dev'``, or ``None`` when nothing matches.
    """
    # (dotted path, required number of child entries, label)
    by_block_count = (
        ('model.diffusion_model.input_blocks', 12, 'sd15'),
        ('model.diffusion_model.input_blocks', 9, 'sdxl'),
        ('model.diffusion_model.joint_blocks', 24, 'sd35-medium'),
        ('model.diffusion_model.joint_blocks', 38, 'sd35-large'),
    )
    for path, expected, label in by_block_count:
        blocks = has(dct, path)
        if blocks and len(list(blocks)) == expected:
            return label

    double_blocks = has(dct, 'model.diffusion_model.double_blocks')
    if double_blocks and len(list(double_blocks)) == 19:
        # Same block count for both; the extra layer tells them apart.
        if has(dct, 'model.diffusion_model.distilled_guidance_layer'):
            return 'chroma'
        return 'flux-dev'
    return None
|
|
|
|
|
|
def read_keys(fn):
    """Print the tensor-key layout of a ``.safetensors`` file and a model guess.

    Files whose name does not end in ``.safetensors`` (case-insensitive) are
    ignored. For matching files the keys are read via ``safe_open`` and the
    following is printed: the file name, the compacted two-segment key
    list, the key tree pruned to depth 3 and to depth 6, and the label
    returned by ``guess_dct``. A failure while opening or reading the file is
    printed and the report continues with an empty key list.

    Args:
        fn: Path of the file to inspect.

    Returns:
        The keys read from the file (an empty list if reading failed), or
        ``None`` when the file was skipped because of its extension.
    """
    if not fn.lower().endswith(".safetensors"):
        return None

    from safetensors.torch import safe_open

    keys = []
    try:
        with safe_open(fn, framework="pt", device="cpu") as handle:
            keys = handle.keys()
    except Exception as err:
        pprint(err)

    tree = list_to_dict(keys)
    compact = list_compact(keys)
    pprint(f'file: {fn}')
    pprint(compact)
    for max_depth in (3, 6):
        pprint(remove_entries_after_depth(tree, max_depth))
    pprint(f'guess: {guess_dct(tree)}')
    return keys
|
|
|
|
|
|
def main():
    """Run ``read_keys`` on every path listed in ``sys.argv``.

    Every element of ``sys.argv`` is treated as a path, so the caller is
    expected to have removed the script name beforehand. A regular file is
    passed to ``read_keys`` directly; a directory is walked recursively and
    each file found is passed to ``read_keys``. Entries that are neither are
    ignored. A notice is printed when ``sys.argv`` is empty.
    """
    if not sys.argv:
        print('metadata:', 'no files specified')
    for target in sys.argv:
        if os.path.isfile(target):
            read_keys(target)
            continue
        if not os.path.isdir(target):
            continue
        for folder, _subdirs, names in os.walk(target):
            for name in names:
                read_keys(os.path.join(folder, name))
|
|
|
|
if __name__ == '__main__':
    # Drop the script name so that main() sees only the user-supplied paths.
    del sys.argv[0]
    main()
|