import os
from typing import Type, Callable, TypeVar, Dict, Any

import torch
import diffusers
from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextModelWithProjection


class ENVStore:
    """Persists annotated attributes as SDNEXT_OLIVE_* environment variables.

    Values are (de)serialized according to the attribute's type annotation;
    reading an unset attribute returns None and assignments to names without
    an annotation are silently ignored.
    """

    __DESERIALIZER: Dict[Type, Callable[[str], Any]] = {
        bool: lambda x: bool(int(x)),
        int: int,
        str: lambda x: x,
    }
    __SERIALIZER: Dict[Type, Callable[[Any], str]] = {
        bool: lambda x: str(int(x)),
        int: str,
        str: lambda x: x,
    }

    def __getattr__(self, name: str):
        value = os.environ.get(f"SDNEXT_OLIVE_{name}", None)
        if value is None:
            return value
        ty = self.__class__.__annotations__[name]
        deserialize = self.__DESERIALIZER[ty]
        return deserialize(value)

    def __setattr__(self, name: str, value) -> None:
        if name not in self.__class__.__annotations__:
            return
        ty = self.__class__.__annotations__[name]
        serialize = self.__SERIALIZER[ty]
        os.environ[f"SDNEXT_OLIVE_{name}"] = serialize(value)

    def __delattr__(self, name: str) -> None:
        if name not in self.__class__.__annotations__:
            return
        key = f"SDNEXT_OLIVE_{name}"
        if key not in os.environ:
            return
        os.environ.pop(key)

class OliveOptimizerConfig(ENVStore):
    # Shared settings read by the Olive callbacks below; values are persisted
    # as SDNEXT_OLIVE_* environment variables via ENVStore.
    from_diffusers_cache: bool

    is_sdxl: bool

    vae: str
    vae_sdxl_fp16_fix: bool

    width: int
    height: int
    batch_size: int

    cross_attention_dim: int
    time_ids_size: int


config = OliveOptimizerConfig()
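
# Illustrative usage sketch (not part of the original module): because
# OliveOptimizerConfig is an ENVStore, attribute access round-trips through
# environment variables, e.g.
#   config.width = 512     # sets os.environ["SDNEXT_OLIVE_width"] = "512"
#   config.width           # reads the variable back and returns the int 512
#   config.height          # returns None while SDNEXT_OLIVE_height is unset
#   del config.width       # removes SDNEXT_OLIVE_width again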


def get_variant():
    # Resolve the configured diffusers_model_load_variant; 'default' follows the active device dtype
    from modules.shared import opts

    if opts.diffusers_model_load_variant == 'default':
        from modules import devices

        if devices.dtype == torch.float16:
            return 'fp16'

        return None
    elif opts.diffusers_model_load_variant == 'fp32':
        return None
    else:
        return opts.diffusers_model_load_variant


def get_loader_arguments(no_variant: bool = False):
    kwargs = {}

    if config.from_diffusers_cache:
        from modules.shared import opts
        kwargs["cache_dir"] = opts.diffusers_dir
        if not no_variant:
            kwargs["variant"] = get_variant()

    return kwargs
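
# Illustrative sketch (assumption, not from the original source): when loading
# straight from the diffusers cache with a float16 device dtype and the default
# variant setting, the kwargs returned above would look roughly like
#   {"cache_dir": opts.diffusers_dir, "variant": "fp16"}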


T = TypeVar("T")
def from_pretrained(cls: Type[T], pretrained_model_name_or_path: os.PathLike, *args, no_variant: bool = False, **kwargs) -> T:
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if pretrained_model_name_or_path.endswith(".onnx"):
        # Paths ending in .onnx are loaded with diffusers.OnnxRuntimeModel from their parent directory
        cls = diffusers.OnnxRuntimeModel
        pretrained_model_name_or_path = os.path.dirname(pretrained_model_name_or_path)
    return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs, **get_loader_arguments(no_variant))
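
# Illustrative sketch (not part of the original module): a path ending in
# ".onnx" switches the loader to diffusers.OnnxRuntimeModel, e.g.
#   from_pretrained(CLIPTextModel, "/models/optimized/text_encoder/model.onnx")
# loads the ONNX model from "/models/optimized/text_encoder" rather than
# instantiating CLIPTextModel; the path above is hypothetical.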


# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------


# Helper latency-only dataloader that creates random tensors with no label
class RandomDataLoader:
    def __init__(self, create_inputs_func, batchsize, torch_dtype):
        self.create_inputs_func = create_inputs_func
        self.batchsize = batchsize
        self.torch_dtype = torch_dtype

    def __getitem__(self, idx):
        label = None
        return self.create_inputs_func(self.batchsize, self.torch_dtype), label
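
# Illustrative usage sketch (not part of the original module): indices are
# ignored and every item is a fresh batch of random inputs paired with a None
# label (a latency-only loader, per the comment above), e.g.
#   loader = RandomDataLoader(text_encoder_inputs, 1, torch.int32)
#   inputs, label = loader[0]   # inputs is the dict/tensor built by text_encoder_inputs, label is None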

# -----------------------------------------------------------------------------
# TEXT ENCODER
# -----------------------------------------------------------------------------


def text_encoder_inputs(batchsize, torch_dtype): # pylint: disable=unused-argument
    input_ids = torch.zeros((config.batch_size, 77), dtype=torch_dtype)
    return {
        "input_ids": input_ids,
        "output_hidden_states": True,
    } if config.is_sdxl else input_ids


def text_encoder_load(model_name):
    model = from_pretrained(CLIPTextModel, model_name, subfolder="text_encoder")
    return model


def text_encoder_conversion_inputs(model): # pylint: disable=unused-argument
    return text_encoder_inputs(1, torch.int32)


def text_encoder_data_loader(data_dir, batchsize, *_, **__): # pylint: disable=unused-argument
    return RandomDataLoader(text_encoder_inputs, config.batch_size, torch.int32)


# -----------------------------------------------------------------------------
# TEXT ENCODER 2
# -----------------------------------------------------------------------------


def text_encoder_2_inputs(batchsize, torch_dtype): # pylint: disable=unused-argument
    return {
        "input_ids": torch.zeros((config.batch_size, 77), dtype=torch_dtype),
        "output_hidden_states": True,
    }


def text_encoder_2_load(model_name):
    model = from_pretrained(CLIPTextModelWithProjection, model_name, subfolder="text_encoder_2")
    return model


def text_encoder_2_conversion_inputs(model): # pylint: disable=unused-argument
    return text_encoder_2_inputs(1, torch.int64)


def text_encoder_2_data_loader(data_dir, batchsize, *_, **__): # pylint: disable=unused-argument
    return RandomDataLoader(text_encoder_2_inputs, config.batch_size, torch.int64)


# -----------------------------------------------------------------------------
# UNET
# -----------------------------------------------------------------------------


def unet_inputs(batchsize, torch_dtype, is_conversion_inputs=False): # pylint: disable=unused-argument
    if config.is_sdxl:
        # SDXL UNet: doubled batch dimension plus the SDXL-specific added conditions (text_embeds, time_ids)
        inputs = {
            "sample": torch.rand((2 * config.batch_size, 4, config.height // 8, config.width // 8), dtype=torch_dtype),
            "timestep": torch.rand((1,), dtype=torch_dtype),
            "encoder_hidden_states": torch.rand((2 * config.batch_size, 77, config.cross_attention_dim), dtype=torch_dtype),
        }

        if is_conversion_inputs:
            inputs["additional_inputs"] = {
                "added_cond_kwargs": {
                    "text_embeds": torch.rand((2 * config.batch_size, 1280), dtype=torch_dtype),
                    "time_ids": torch.rand((2 * config.batch_size, config.time_ids_size), dtype=torch_dtype),
                }
            }
        else:
            inputs["text_embeds"] = torch.rand((2 * config.batch_size, 1280), dtype=torch_dtype)
            inputs["time_ids"] = torch.rand((2 * config.batch_size, config.time_ids_size), dtype=torch_dtype)
    else:
        # Non-SDXL UNet
        inputs = {
            "sample": torch.rand((config.batch_size, 4, config.height // 8, config.width // 8), dtype=torch_dtype),
            "timestep": torch.rand((config.batch_size,), dtype=torch_dtype),
            "encoder_hidden_states": torch.rand((config.batch_size, 77, config.cross_attention_dim), dtype=torch_dtype),
        }

        # use as kwargs since they won't be in the correct position if passed along with the tuple of inputs
        kwargs = {
            "return_dict": False,
        }
        if is_conversion_inputs:
            inputs["additional_inputs"] = {
                **kwargs,
                "added_cond_kwargs": {
                    "text_embeds": torch.rand((1, 1280), dtype=torch_dtype),
                    "time_ids": torch.rand((1, 5), dtype=torch_dtype),
                },
            }
        else:
            inputs.update(kwargs)
            # "onnx::Concat_4" / "onnx::Shape_5" are the auto-generated names these tensors get in the exported ONNX graph
            inputs["onnx::Concat_4"] = torch.rand((1, 1280), dtype=torch_dtype)
            inputs["onnx::Shape_5"] = torch.rand((1, 5), dtype=torch_dtype)

    return inputs
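
# Illustrative shape sketch (assumption, not from the original source): with
# width=512, height=512, batch_size=1, cross_attention_dim=768 and is_sdxl
# false, the latency inputs above come out as
#   sample:                (1, 4, 64, 64)
#   timestep:              (1,)
#   encoder_hidden_states: (1, 77, 768)
# while the SDXL path doubles the batch dimension to 2.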


def unet_load(model_name):
    model = from_pretrained(diffusers.UNet2DConditionModel, model_name, subfolder="unet")
    return model


def unet_conversion_inputs(model): # pylint: disable=unused-argument
    return tuple(unet_inputs(1, torch.float32, True).values())


def unet_data_loader(data_dir, batchsize, *_, **__): # pylint: disable=unused-argument
    return RandomDataLoader(unet_inputs, config.batch_size, torch.float16)


# -----------------------------------------------------------------------------
# VAE ENCODER
# -----------------------------------------------------------------------------


def vae_encoder_inputs(batchsize, torch_dtype): # pylint: disable=unused-argument
    return {
        "sample": torch.rand((config.batch_size, 3, config.height, config.width), dtype=torch_dtype),
        "return_dict": False,
    }


def vae_encoder_load(model_name):
    subfolder = "vae_encoder" if os.path.isdir(os.path.join(model_name, "vae_encoder")) else "vae"

    if config.vae_sdxl_fp16_fix:
        model_name = "madebyollin/sdxl-vae-fp16-fix"
        subfolder = ""

    if config.vae is None:
        model = from_pretrained(diffusers.AutoencoderKL, model_name, subfolder=subfolder, no_variant=config.vae_sdxl_fp16_fix)
    else:
        model = diffusers.AutoencoderKL.from_single_file(config.vae)

    # Export only the encoder path: forward maps a pixel-space sample to sampled latents
    model.forward = lambda sample, return_dict: model.encode(sample, return_dict)[0].sample()

    return model


def vae_encoder_conversion_inputs(model): # pylint: disable=unused-argument
    return tuple(vae_encoder_inputs(1, torch.float32).values())


def vae_encoder_data_loader(data_dir, batchsize, *_, **__): # pylint: disable=unused-argument
    return RandomDataLoader(vae_encoder_inputs, config.batch_size, torch.float16)


# -----------------------------------------------------------------------------
# VAE DECODER
# -----------------------------------------------------------------------------


def vae_decoder_inputs(batchsize, torch_dtype): # pylint: disable=unused-argument
    return {
        "latent_sample": torch.rand((config.batch_size, 4, config.height // 8, config.width // 8), dtype=torch_dtype),
        "return_dict": False,
    }


def vae_decoder_load(model_name):
    subfolder = "vae_decoder" if os.path.isdir(os.path.join(model_name, "vae_decoder")) else "vae"

    if config.vae_sdxl_fp16_fix:
        model_name = "madebyollin/sdxl-vae-fp16-fix"
        subfolder = ""

    if config.vae is None:
        model = from_pretrained(diffusers.AutoencoderKL, model_name, subfolder=subfolder, no_variant=config.vae_sdxl_fp16_fix)
    else:
        model = diffusers.AutoencoderKL.from_single_file(config.vae)

    # Export only the decoder path
    model.forward = model.decode

    return model


def vae_decoder_conversion_inputs(model): # pylint: disable=unused-argument
    return tuple(vae_decoder_inputs(1, torch.float32).values())


def vae_decoder_data_loader(data_dir, batchsize, *_, **__): # pylint: disable=unused-argument
    return RandomDataLoader(vae_decoder_inputs, config.batch_size, torch.float16)
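
# How these callbacks are consumed (sketch based on the Olive stable-diffusion
# example layout, not on this repository's exact config): an Olive workflow
# config typically points its input model and data settings at this file,
# referencing e.g. "unet_load" as the model loader, "unet_conversion_inputs"
# for dummy conversion inputs and "unet_data_loader" as the dataloader. The
# exact key names depend on the Olive version, so treat the pairing of
# functions, not any particular schema, as the takeaway.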