1
0
# Mirror of https://github.com/vladmandic/sdnext.git (synced 2026-01-27 15:02:48 +03:00)
# File: sdnext/modules/model_tools.py (2025-03-16 21:45:05 -04:00)
#
# 82 lines, 2.8 KiB, Python

import inspect
import diffusers
import transformers
import safetensors.torch
from modules import shared, devices, model_quant
def remove_entries_after_depth(d, depth, current_depth=0):
    """Prune a nested dict to a maximum recursion depth.

    Args:
        d: value to prune; only ``dict`` values are descended into.
        depth: recursion limit. A call made with ``current_depth >= depth``
            returns ``None``, which removes that entry from its parent.
        current_depth: recursion level of this call (callers normally omit it).

    Returns:
        ``None`` when ``current_depth >= depth``. Otherwise a non-dict ``d`` is
        returned unchanged, and a dict yields a new dict that keeps only the
        entries whose pruned value is not ``None``. As a result, only the first
        ``depth - 1`` levels of keys survive: ``depth=2`` keeps just the
        top-level keys, with their sub-dicts emptied. Entries whose value is
        ``None`` are always dropped.
    """
    if current_depth >= depth:
        return None
    if not isinstance(d, dict):
        return d
    pruned = {}
    for key, value in d.items():
        # Recurse exactly once per entry. Recursing in both the value and the
        # filter of a comprehension doubles the work at every level
        # (O(2**depth) calls).
        child = remove_entries_after_depth(value, depth, current_depth + 1)
        if child is not None:  # drops depth-truncated entries and None-valued leaves alike
            pruned[key] = child
    return pruned
def list_compact(flat_list, levels=2):
    """Collapse dotted names to their leading segments, without duplicates.

    Args:
        flat_list: iterable of dotted names such as ``'model.layer.weight'``.
        levels: number of leading ``.``-separated segments to keep (default 2,
            so ``'a.b.c'`` becomes ``'a.b'``).

    Returns:
        list of unique prefixes in first-seen order.
    """
    # dict.fromkeys de-duplicates in O(n) while preserving insertion order;
    # the previous `prefix not in list` check was O(n**2) on large key lists.
    return list(dict.fromkeys('.'.join(item.split('.')[:levels]) for item in flat_list))
def list_to_dict(flat_list):
    """Build a nested dict tree from dotted names.

    ``'a.b.c'`` becomes ``{'a': {'b': {'c': None}}}``. Names that share a
    prefix are merged into the same subtree, and leaves are ``None``.

    A name that is both a leaf and a prefix of another name (``'a'`` and
    ``'a.b'``) ends up as a subtree, whatever the input order. Non-string
    entries are skipped instead of aborting the whole conversion.

    Args:
        flat_list: iterable of dotted names (e.g. tensor keys), or ``None``.

    Returns:
        nested dict; ``{}`` for an empty or ``None`` input.
    """
    result_dict = {}
    for item in flat_list or ():
        if not isinstance(item, str):
            continue
        *parents, leaf = item.split('.')
        node = result_dict
        for key in parents:
            child = node.get(key)
            if not isinstance(child, dict):
                # Promote an earlier leaf (None) to a subtree. Calling
                # setdefault on None would otherwise raise and stop processing.
                child = {}
                node[key] = child
            node = child
        # setdefault: never replace an existing subtree with a leaf marker
        node.setdefault(leaf, None)
    return result_dict
def get_safetensor_keys(filename):
    """Return the tensor names stored in a safetensors file.

    Args:
        filename: path to a ``.safetensors`` file.

    Returns:
        the result of ``safe_open(...).keys()``, or ``[]`` if the file cannot be
        opened or read. The failure is logged rather than silently ignored.
    """
    keys = []
    try:
        with safetensors.torch.safe_open(filename, framework="pt", device="cpu") as f:
            keys = f.keys()
    except Exception as e:
        # Still best-effort (callers get an empty list), but surface why so a
        # missing or corrupt file is not mistaken for an empty one.
        shared.log.warning(f'Safetensors: file="{filename}" {e}')
    return keys
def get_modules(model: callable):
    """Find the ``__init__`` parameters of *model* whose annotated type is loadable.

    A parameter qualifies when it has a type annotation and that annotation
    has a ``from_pretrained`` attribute. Each match is logged at debug level.

    Args:
        model: class (or other object with an ``__init__``) to inspect,
            typically a pipeline class.

    Returns:
        dict mapping parameter name -> annotation, in signature order.
    """
    signature = inspect.signature(model.__init__, follow_wrapped=True)
    params = {
        param.name: param.annotation
        for param in signature.parameters.values()
        # inspect.Parameter.empty is the public "no annotation" sentinel; compare by identity
        if param.annotation is not inspect.Parameter.empty and hasattr(param.annotation, 'from_pretrained')
    }
    for name, cls in params.items():
        # getattr: an annotation that is not a class may lack __name__
        shared.log.debug(f'Analyze: model={model} module={name} class={getattr(cls, "__name__", cls)} loadable={getattr(cls, "from_pretrained", None)}')
    return params
def load_modules(repo_id: str, params: dict):
    """Load the supported components listed in *params* from *repo_id*.

    Args:
        repo_id: repo id or path passed straight to each class's ``from_pretrained``.
        params: mapping of component name -> component class (the shape
            ``get_modules`` returns).

    Returns:
        dict of component name -> loaded model. Components whose class is not
        handled below are skipped.
    """
    cache_dir = shared.opts.hfcache_dir
    # Component names that are also subfolder names in a diffusers-layout repo.
    known_subfolders = ('vae', 'text_encoder', 'text_encoder_2', 'text_encoder_3', 'transformer')
    modules = {}
    for name, cls in params.items():
        subfolder = None
        kwargs = {}
        if cls == diffusers.AutoencoderKL:
            subfolder = 'vae'
        elif cls == transformers.CLIPTextModel:  # clip-vit-l
            subfolder = 'text_encoder'
        elif cls == transformers.CLIPTextModelWithProjection:  # clip-vit-g
            subfolder = 'text_encoder_2'
        elif cls == transformers.T5EncoderModel:  # t5-xxl
            subfolder = 'text_encoder_3'
            # NOTE(review): create_config presumably adds quantization settings to kwargs — confirm in model_quant
            kwargs = model_quant.create_config(kwargs)
            kwargs['variant'] = 'fp16'
        elif cls == diffusers.SD3Transformer2DModel:
            subfolder = 'transformer'
            kwargs = model_quant.create_config(kwargs)
        if subfolder is None:
            continue  # class not supported by this loader
        if name in known_subfolders:
            # The class alone is ambiguous. diffusers' StableDiffusion3Pipeline
            # annotates both text_encoder and text_encoder_2 as
            # CLIPTextModelWithProjection, so the class map above would load
            # both from 'text_encoder_2'. Prefer the component's own name,
            # which is the subfolder diffusers stores it under; the class map
            # stays as the fallback for other names.
            subfolder = name
        shared.log.debug(f'Load: module={name} class={cls.__name__} repo={repo_id} location={subfolder}')
        modules[name] = cls.from_pretrained(repo_id, subfolder=subfolder, cache_dir=cache_dir, torch_dtype=devices.dtype, **kwargs)
    return modules