mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
163 lines
5.3 KiB
Python
163 lines
5.3 KiB
Python
import os
|
|
import json
|
|
import importlib
|
|
from typing import Type, Tuple, Union, List, Dict, Any
|
|
import torch
|
|
import diffusers
|
|
|
|
|
|
def extract_device(args: List, kwargs: Dict):
    """Pull a device out of call arguments (presumably those of a ``.to(...)``-style call).

    An explicit, non-None ``device`` keyword is returned as-is. Otherwise the
    positional arguments are scanned for ``torch.device`` instances; when more
    than one is present, the last one wins. Returns ``None`` if nothing matches.
    """
    explicit = kwargs.get("device", None)
    if explicit is not None:
        return explicit
    positional = [candidate for candidate in args if isinstance(candidate, torch.device)]
    return positional[-1] if positional else None
|
|
|
|
|
|
def move_inference_session(session, device: torch.device): # session: ort.InferenceSession
    """Reload an ONNX Runtime session with the execution provider matching *device*.

    Returns *session* unchanged when the default torch device is CPU and the
    backend is not OpenVINO. Otherwise the model is reloaded from the session's
    own model path, using the provider mapped from ``device.type`` through
    ``TORCH_DEVICE_TO_EP`` (or the session's current providers when the device
    type is not mapped). If the reload raises, a ``TemporalModule`` built from
    the previous providers, model path and session options is returned instead.
    """
    from modules.devices import device as default_device
    from modules.devices import backend as default_backend

    # CPU-only torch with no other backend overriding ops: moving the session
    # here would be a mistake, so keep the existing one.
    keep_in_place = default_device.type == "cpu" and default_backend != "openvino"
    if keep_in_place:
        return session

    from . import DynamicSessionOptions, TemporalModule
    from .execution_providers import TORCH_DEVICE_TO_EP

    old_providers = session._providers # pylint: disable=protected-access
    if device.type in TORCH_DEVICE_TO_EP:
        target_provider = TORCH_DEVICE_TO_EP[device.type]
    else:
        target_provider = old_providers
    model_path = session._model_path # pylint: disable=protected-access

    try:
        options = DynamicSessionOptions.from_sess_options(session._sess_options) # pylint: disable=protected-access
        return diffusers.OnnxRuntimeModel.load_model(model_path, target_provider, options)
    except Exception:
        # Best effort: fall back to a TemporalModule built from the previous configuration.
        return TemporalModule(old_providers, model_path, session._sess_options) # pylint: disable=protected-access
|
|
|
|
|
|
def check_diffusers_cache(path: os.PathLike):
    """Return whether *path* appears to be inside the configured diffusers cache.

    NOTE(review): this is a plain substring test of ``opts.diffusers_dir``
    against the absolute form of *path*, not a strict path-containment check.
    """
    from modules.shared import opts

    cache_dir = opts.diffusers_dir
    absolute_path = os.path.abspath(path)
    return cache_dir in absolute_path
|
|
|
|
|
|
def check_pipeline_sdxl(cls: Type[diffusers.DiffusionPipeline]) -> bool:
    """Name-based heuristic: True when the pipeline class name contains ``'XL'``."""
    class_name = cls.__name__
    return class_name.find('XL') != -1
|
|
|
|
|
|
def check_cache_onnx(path: os.PathLike) -> bool:
    """Return True if *path* is a directory whose ``model_index.json`` mentions ONNX.

    The check is a raw text search for ``"OnnxRuntimeModel"`` in the index
    file; the JSON is not parsed. A missing directory or index file yields False.
    """
    if os.path.isdir(path):
        index_file = os.path.join(path, "model_index.json")
        if os.path.isfile(index_file):
            with open(index_file, "r", encoding="utf-8") as handle:
                return "OnnxRuntimeModel" in handle.read()
    return False
|
|
|
|
|
|
def load_init_dict(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike):
    """Build the constructor config for *cls* from the pipeline config at *path*.

    Every dict returned by ``cls.extract_init_dict`` is merged, with later ones
    overriding earlier keys. Non-list values pass through. List values
    (presumably ``[library, class]`` component specs) are kept only when their
    key is annotated on ``cls.__init__``.
    """
    config = diffusers.DiffusionPipeline.load_config(path)
    merged: Dict[str, Any] = {}
    for part in cls.extract_init_dict(config):
        merged.update(part)

    # `or` short-circuits so __annotations__ is only consulted for list values.
    return {
        key: value
        for key, value in merged.items()
        if not isinstance(value, list) or key in cls.__init__.__annotations__
    }
|
|
|
|
|
|
def load_submodel(path: os.PathLike, is_sdxl: bool, submodel_name: str, item: List[Union[str, None]], **kwargs_ort):
    """Load one pipeline component from ``<path>/<submodel_name>``.

    *item* is a ``[library, class_name]`` pair; if either entry is None the
    component is absent and None is returned. ``OnnxRuntimeModel`` subclasses
    are loaded with *kwargs_ort*: via ``load_model`` on ``model.onnx`` for
    SDXL, otherwise via ``from_pretrained`` on the directory. Any other class
    is loaded with its own ``from_pretrained`` (without *kwargs_ort*).
    """
    library_name, class_name = item
    if library_name is None or class_name is None:
        return None

    # NOTE(review): library/class names come from the model's config file and
    # are imported dynamically; only load models from trusted sources.
    component_cls = getattr(importlib.import_module(library_name), class_name)
    submodel_dir = os.path.join(path, submodel_name)

    if not issubclass(component_cls, diffusers.OnnxRuntimeModel):
        return component_cls.from_pretrained(submodel_dir)

    if is_sdxl:
        return diffusers.OnnxRuntimeModel.load_model(os.path.join(submodel_dir, "model.onnx"), **kwargs_ort)
    return diffusers.OnnxRuntimeModel.from_pretrained(submodel_dir, **kwargs_ort)
|
|
|
|
|
|
def load_submodels(path: os.PathLike, is_sdxl: bool, init_dict: Dict[str, Type], **kwargs_ort):
    """Load every component described in *init_dict*.

    Non-list values are copied through unchanged. List values are passed to
    ``load_submodel``. A component that fails to load is skipped, so loading
    stays best-effort. The failure is logged rather than silently discarded,
    so a later constructor error about a missing argument can be traced back
    to its cause.

    Args:
        path: pipeline directory containing one sub-directory per component.
        is_sdxl: forwarded to ``load_submodel``.
        init_dict: component name -> config value or ``[library, class]`` spec.
        **kwargs_ort: forwarded to ``load_submodel``.

    Returns:
        dict mapping component name to the loaded object or passthrough value.
    """
    import logging

    logger = logging.getLogger(__name__)
    loaded = {}
    for k, v in init_dict.items():
        if not isinstance(v, list):
            loaded[k] = v
            continue
        try:
            loaded[k] = load_submodel(path, is_sdxl, k, v, **kwargs_ort)
        except Exception as e:
            # Keep going with the remaining components, but leave a trace.
            logger.warning("Failed to load submodel '%s' from %s: %s", k, path, e)
    return loaded
|
|
|
|
|
|
def load_pipeline(cls: Type[diffusers.DiffusionPipeline], path: os.PathLike, **kwargs_ort) -> diffusers.DiffusionPipeline:
    """Construct pipeline *cls* from *path*.

    A directory is loaded component by component (with *kwargs_ort* forwarded
    to the component loaders) and the resulting kwargs are patched for the
    specific pipeline class. Any other path is handed to
    ``cls.from_single_file`` without *kwargs_ort*.
    """
    if not os.path.isdir(path):
        return cls.from_single_file(path)

    is_sdxl = check_pipeline_sdxl(cls)
    init_dict = load_init_dict(cls, path)
    components = load_submodels(path, is_sdxl, init_dict, **kwargs_ort)
    return cls(**patch_kwargs(cls, components))
|
|
|
|
|
|
def patch_kwargs(cls: Type[diffusers.DiffusionPipeline], kwargs: Dict) -> Dict:
    """Apply class-specific constructor overrides to *kwargs* in place and return it.

    The SD ONNX txt2img, img2img and inpaint pipelines get the safety checker
    disabled. The SDXL ONNX txt2img and img2img pipelines get an empty
    ``config`` dict.
    """
    is_sd_onnx = (
        cls == diffusers.OnnxStableDiffusionPipeline
        or cls == diffusers.OnnxStableDiffusionImg2ImgPipeline
        or cls == diffusers.OnnxStableDiffusionInpaintPipeline
    )
    if is_sd_onnx:
        kwargs["safety_checker"] = None
        kwargs["requires_safety_checker"] = False

    is_sdxl_onnx = (
        cls == diffusers.OnnxStableDiffusionXLPipeline
        or cls == diffusers.OnnxStableDiffusionXLImg2ImgPipeline
    )
    if is_sdxl_onnx:
        kwargs["config"] = {}

    return kwargs
|
|
|
|
|
|
def get_base_constructor(cls: Type[diffusers.DiffusionPipeline], is_refiner: bool):
    """Return the base pipeline class used to build *cls*.

    SD ONNX img2img and inpaint pipelines map to ``OnnxStableDiffusionPipeline``.
    The SDXL ONNX img2img pipeline maps to ``OnnxStableDiffusionXLPipeline``
    unless *is_refiner* is set. Every other class maps to itself.
    """
    derives_from_sd = cls == diffusers.OnnxStableDiffusionImg2ImgPipeline or cls == diffusers.OnnxStableDiffusionInpaintPipeline
    if derives_from_sd:
        return diffusers.OnnxStableDiffusionPipeline

    if cls == diffusers.OnnxStableDiffusionXLImg2ImgPipeline:
        if not is_refiner:
            return diffusers.OnnxStableDiffusionXLPipeline

    return cls
|
|
|
|
|
|
def get_io_config(submodel: str, is_sdxl: bool):
    """Load the ``io_config`` section of an Olive submodel config.

    Reads ``<sd_configs_path>/olive/{sdxl|sd}/<submodel>.json`` and returns
    ``["input_model"]["config"]["io_config"]``. The inner keys of each
    ``dynamic_axes`` entry are converted to int, because JSON object keys are
    always strings (presumably these are axis indices).
    """
    from modules.paths import sd_configs_path

    family = 'sdxl' if is_sdxl else 'sd'
    config_path = os.path.join(sd_configs_path, "olive", family, f"{submodel}.json")
    with open(config_path, "r", encoding="utf-8") as config_file:
        document = json.load(config_file)
    io_config: Dict[str, Any] = document["input_model"]["config"]["io_config"]

    axes = io_config["dynamic_axes"]
    for name, mapping in list(axes.items()):
        axes[name] = {int(index): label for index, label in mapping.items()}

    return io_config
|