sdnext/modules/sd_vae_ostris.py
Vladimir Mandic b569d68b4f add ostris: experimental
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2024-10-26 21:51:36 -04:00


import time
import torch
import diffusers
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from modules import shared, devices


# 16-channel VAE decoder and matching UNet adapter weights published by ostris
decoder_id = "ostris/vae-kl-f8-d16"
adapter_id = "ostris/16ch-VAE-Adapters"


def load_vae(pipe):
    # pick the adapter checkpoint that matches the loaded base model
    if shared.sd_model_type == 'sd':
        adapter_file = "16ch-VAE-Adapter-SD15-alpha.safetensors"
    elif shared.sd_model_type == 'sdxl':
        adapter_file = "16ch-VAE-Adapter-SDXL-alpha_v02.safetensors"
    else:
        shared.log.error('VAE: type=ostris unsupported model type')
        return
    t0 = time.time()
    # download the adapter checkpoint and split it into LoRA and UNet tensors
    ckpt_file = hf_hub_download(adapter_id, adapter_file, cache_dir=shared.opts.hfcache_dir)
    ckpt = load_file(ckpt_file)
    lora_state_dict = {k: v for k, v in ckpt.items() if "lora" in k}
    unet_state_dict = {k.replace("unet_", ""): v for k, v in ckpt.items() if "unet_" in k}
    # replace the UNet input/output convolutions so they take 16 latent channels instead of 4
    pipe.unet.conv_in = torch.nn.Conv2d(16, 320, 3, 1, 1)
    pipe.unet.conv_out = torch.nn.Conv2d(320, 16, 3, 1, 1)
    pipe.unet.load_state_dict(unet_state_dict, strict=False)
    pipe.unet.conv_in.to(devices.dtype)
    pipe.unet.conv_out.to(devices.dtype)
    pipe.unet.config.in_channels = 16
    pipe.unet.config.out_channels = 16
    # load the adapter LoRA and fuse it into the UNet weights
    pipe.load_lora_weights(lora_state_dict, adapter_name=adapter_id)
    # pipe.set_adapters(adapter_names=[adapter_id], adapter_weights=[0.8])
    pipe.fuse_lora(adapter_names=[adapter_id], lora_scale=0.8, fuse_unet=True)
    # swap in the 16-channel VAE decoder
    pipe.vae = diffusers.AutoencoderKL.from_pretrained(decoder_id, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
    t1 = time.time()
    shared.log.info(f'VAE load: type=ostris decoder="{decoder_id}" adapter="{adapter_id}" time={t1-t0:.2f}s')
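

# Usage sketch (illustrative only, not part of the original module): assuming
# `pipe` is an SD 1.5 or SDXL diffusers pipeline already loaded by sdnext,
# the 16-channel VAE and adapter would be applied roughly like this:
#   from modules import sd_vae_ostris
#   sd_vae_ostris.load_vae(pipe)  # swaps conv_in/conv_out, fuses the adapter LoRA, replaces pipe.vae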