1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/sdnq/file_loader.py

65 lines
2.6 KiB
Python

import re
import concurrent.futures
import torch
def map_keys(key: str, key_mapping: dict) -> str:
    """Rename ``key`` using the first regex rule in ``key_mapping`` that matches.

    Args:
        key: Original tensor key.
        key_mapping: Mapping of regex pattern -> replacement string, tried in
            the mapping's iteration order. May be None or empty.

    Returns:
        The key after applying the first rule that performed at least one
        substitution, or ``key`` unchanged when no rule matched or no mapping
        was supplied.
    """
    if not key_mapping:
        return key
    for regex, substitute in key_mapping.items():
        renamed, hits = re.subn(regex, substitute, key)
        if hits:
            # Only the first matching rule is applied; later rules are skipped.
            return renamed
    return key
def load_safetensors(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu") -> dict:
    """Read every tensor from the given safetensors files into ``state_dict``.

    Files are processed sequentially in the order given; when two tensors end
    up with the same (mapped) key, the one read last wins.

    Args:
        files: Paths of the safetensors files to read.
        state_dict: Dictionary to fill in place; a new one is created when None.
        key_mapping: Optional regex pattern -> replacement mapping, forwarded to
            ``map_keys`` to rename each tensor key before it is stored.
        device: Target device; passed to ``safe_open`` as ``str(device)``.

    Returns:
        The populated dictionary (the same object as ``state_dict`` when one
        was supplied).
    """
    from safetensors.torch import safe_open
    if state_dict is None:
        state_dict = {}
    for fn in files:
        with safe_open(fn, framework="pt", device=str(device)) as f:
            for key in f.keys():
                state_dict[map_keys(key, key_mapping)] = f.get_tensor(key)
    # Return the dict so that callers who did not pass one can reach the tensors.
    return state_dict
def load_threaded(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu") -> dict:
    """Load safetensors files concurrently, one ``load_safetensors`` task per file.

    Up to four worker threads run at once and all of them write into the same
    ``state_dict`` object.

    Args:
        files: Paths of the safetensors files to read.
        state_dict: Dictionary to fill in place; a new one is created when None.
        key_mapping: Optional regex pattern -> replacement mapping, forwarded to
            ``load_safetensors`` for key renaming.
        device: Target device, forwarded to ``load_safetensors``.

    Returns:
        The populated dictionary (the same object as ``state_dict`` when one
        was supplied).

    Raises:
        Exception: Whatever a worker task raised is re-raised here.
    """
    if state_dict is None:
        state_dict = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(load_safetensors, [fn], key_mapping=key_mapping, device=device, state_dict=state_dict)
            for fn in files
        ]
        for future in concurrent.futures.as_completed(futures):
            # result() re-raises any exception from the worker instead of dropping it.
            future.result()
    # Return the dict so that callers who did not pass one can reach the tensors.
    return state_dict
def load_streamer(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu") -> dict:
    """Load tensors from safetensors files through ``runai_model_streamer``.

    Requires the optional ``runai_model_streamer`` package
    (``pip install runai_model_streamer``); it is imported only when this
    function is called.

    Args:
        files: Paths of the safetensors files, handed to the streamer in one call.
        state_dict: Dictionary to fill in place; a new one is created when None.
        key_mapping: Optional regex pattern -> replacement mapping, forwarded to
            ``map_keys`` to rename each tensor key before it is stored.
        device: Device each streamed tensor is moved to via ``tensor.to(device)``.

    Returns:
        The populated dictionary (the same object as ``state_dict`` when one
        was supplied).
    """
    # requires pip install runai_model_streamer
    from runai_model_streamer import SafetensorsStreamer
    if state_dict is None:
        state_dict = {}
    with SafetensorsStreamer() as streamer:
        streamer.stream_files(files)
        for key, tensor in streamer.get_tensors():
            state_dict[map_keys(key, key_mapping)] = tensor.to(device)
    # Return the dict so that callers who did not pass one can reach the tensors.
    return state_dict
def load_files(files: list[str], state_dict: dict = None, key_mapping: dict = None, device: torch.device = "cpu", method: str = None) -> dict:
    """Load tensors from ``files`` into a state dict using the chosen backend.

    Note: ``files`` is the list of files within a single module (chunked
    loading), not across the whole model. A single path given as a string is
    treated as a one-element list.

    Args:
        files: File path or list of file paths to load.
        state_dict: Dictionary to fill in place; a new one is created when None.
        key_mapping: Optional regex pattern -> replacement mapping used to
            rename tensor keys.
        device: Target device forwarded to the selected loader.
        method: One of ``"safetensors"`` (default when None), ``"threaded"``
            or ``"streamer"``.

    Returns:
        The dictionary that was handed to the selected loader.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    files = [files] if isinstance(files, str) else files
    method = "safetensors" if method is None else method
    state_dict = {} if state_dict is None else state_dict

    # Reject unknown methods up front, before any loader is looked up.
    if method not in ("safetensors", "threaded", "streamer"):
        raise ValueError(f"Unsupported loading method: {method}")

    if method == "safetensors":
        loader = load_safetensors
    elif method == "threaded":
        loader = load_threaded
    else:
        loader = load_streamer
    loader(files, state_dict=state_dict, key_mapping=key_mapping, device=device)
    return state_dict