mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
42 lines
1.7 KiB
Python
import time
|
|
import torch
|
|
import diffusers
|
|
from huggingface_hub import hf_hub_download
|
|
from safetensors.torch import load_file
|
|
from modules import shared, devices
|
|
|
|
|
|
# Hugging Face Hub repositories used by load_vae: the 16-channel autoencoder and
# the repository holding the SD 1.5 / SDXL UNet adapter checkpoints.
decoder_id, adapter_id = (
    "ostris/vae-kl-f8-d16",  # autoencoder loaded into pipe.vae
    "ostris/16ch-VAE-Adapters",  # adapter .safetensors files (one per model type)
)
|
|
|
|
|
|
def load_vae(pipe, lora_scale: float = 0.8):
    """Switch an SD 1.5 / SDXL pipeline to the ostris 16-channel latent space.

    Downloads the ostris adapter matching ``shared.sd_model_type``. Replaces
    ``pipe.unet.conv_in`` / ``pipe.unet.conv_out`` with 16-channel convolutions
    loaded from that adapter. Loads and fuses the adapter's LoRA weights. Swaps
    ``pipe.vae`` for the autoencoder ``decoder_id``. ``pipe`` is modified in place.

    Args:
        pipe: diffusers pipeline. Part of its ``unet`` is replaced and its ``vae``
            is replaced entirely.
        lora_scale: scale passed to ``fuse_lora`` for the adapter LoRA (default 0.8).

    Returns:
        None. If ``shared.sd_model_type`` is neither ``'sd'`` nor ``'sdxl'``, an
        error is logged and ``pipe`` is left untouched.
    """
    adapter_files = {
        'sd': "16ch-VAE-Adapter-SD15-alpha.safetensors",
        'sdxl': "16ch-VAE-Adapter-SDXL-alpha_v02.safetensors",
    }
    adapter_file = adapter_files.get(shared.sd_model_type)
    if adapter_file is None:
        shared.log.error('VAE: type=osiris unsupported model type')
        return
    t0 = time.time()
    ckpt_file = hf_hub_download(adapter_id, adapter_file, cache_dir=shared.opts.hfcache_dir)
    ckpt = load_file(ckpt_file)
    # Partition the adapter checkpoint. Tensors with "lora" in the key go to the LoRA
    # loader; the remaining "unet_"-tagged tensors are plain UNet weights. "lora" keys
    # are excluded from the UNet dict because LoRA key names can also contain "unet_"
    # (e.g. "lora_unet_...").
    lora_state_dict = {k: v for k, v in ckpt.items() if "lora" in k}
    unet_state_dict = {k.replace("unet_", ""): v for k, v in ckpt.items() if "unet_" in k and "lora" not in k}

    unet = pipe.unet
    # Channel width after conv_in and before conv_out (320 for SD 1.5 and SDXL).
    width = unet.config.block_out_channels[0]
    unet.conv_in = _osiris_conv(16, width, unet.conv_in)
    unet.conv_out = _osiris_conv(width, 16, unet.conv_out)
    # strict=False is needed because unet_state_dict covers only part of the UNet
    # (presumably just the conv_in/conv_out tensors). It also means load_state_dict
    # would silently leave the new layers randomly initialized if their weights were
    # absent, so check for them explicitly.
    expected = {f'{layer}.{name}' for layer in ('conv_in', 'conv_out') for name, _ in getattr(unet, layer).named_parameters()}
    missing = sorted(expected - unet_state_dict.keys())
    result = unet.load_state_dict(unet_state_dict, strict=False)
    if missing or result.unexpected_keys:
        shared.log.warning(f'VAE load: type=osiris adapter="{adapter_id}" file="{adapter_file}" missing={missing} unexpected={result.unexpected_keys}')
    unet.conv_in.to(devices.dtype)
    unet.conv_out.to(devices.dtype)
    # register_to_config is diffusers' API for changing a model config; it rewrites
    # the stored config mapping instead of only setting attributes on the config object.
    unet.register_to_config(in_channels=16, out_channels=16)

    pipe.load_lora_weights(lora_state_dict, adapter_name=adapter_id)
    pipe.fuse_lora(adapter_names=[adapter_id], lora_scale=lora_scale, fuse_unet=True)

    pipe.vae = diffusers.AutoencoderKL.from_pretrained(decoder_id, torch_dtype=devices.dtype, cache_dir=shared.opts.hfcache_dir)
    t1 = time.time()
    shared.log.info(f'VAE load: type=osiris decoder="{decoder_id}" adapter="{adapter_id}" time={t1-t0:.2f}s')


def _osiris_conv(in_channels, out_channels, replaced):
    """Build a 3x3, stride 1, padding 1 Conv2d to stand in for the UNet layer `replaced`.

    The new layer is created on the device of `replaced`, so the swap also works on a
    UNet that already lives on an accelerator. A `replaced` layer on the storage-less
    ``meta`` device falls back to CPU so that real weights can be loaded into it.
    """
    device = replaced.weight.device
    if device.type == 'meta':
        device = torch.device('cpu')
    return torch.nn.Conv2d(in_channels, out_channels, 3, 1, 1, device=device)
|