1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/sd_vae_repa.py
Vladimir Mandic fbd24c290a experiments with repa-e
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-10-27 15:28:20 -04:00

33 lines
1.1 KiB
Python

import diffusers
from modules import shared
# Map of sdnext model type (shared.sd_model_type) -> REPA-E VAE checkpoint and the diffusers class name used to load it.
models = {
    'sd': { 'repo_id': 'REPA-E/e2e-sdvae-hf', 'cls': 'AutoencoderKL' },
    'sdxl': { 'repo_id': 'REPA-E/e2e-sdvae-hf', 'cls': 'AutoencoderKL' }, # NOTE(review): same checkpoint as 'sd' - confirm this is intended for SDXL
    'sd3': { 'repo_id': 'REPA-E/e2e-sd3.5-vae', 'cls': 'AutoencoderKL' },
    'f1': { 'repo_id': 'REPA-E/e2e-flux-vae', 'cls': 'AutoencoderKL' },
    'qwen': { 'repo_id': 'REPA-E/e2e-qwenimage-vae', 'cls': 'AutoencoderKLQwenImage' },
}
# Single-entry cache of the most recently loaded REPA-E VAE.
loaded_cls = None # diffusers class that loaded_vae was created with
loaded_vae = None # cached VAE instance, or None if nothing is loaded
loaded_repo = None # repo_id loaded_vae was created from; several types share AutoencoderKL, so cls alone does not identify the checkpoint


def repa_load(latents):
    """Return the REPA-E VAE for the current model type, loading it if not already cached.

    Args:
        latents: tensor whose ``dtype`` is passed as ``torch_dtype`` when the VAE is loaded.

    Returns:
        The cached or freshly loaded VAE instance, or ``None`` when
        ``shared.sd_model_type`` has no entry in ``models`` (an error is logged).
    """
    global loaded_cls, loaded_vae, loaded_repo # pylint: disable=global-statement
    config = models.get(shared.sd_model_type, None)
    if config is None:
        shared.log.error(f'Decode: type="repa" model={shared.sd_model_type} not supported')
        return None # previously returned the input latents, which is not a VAE
    repo_id = config['repo_id']
    cls = getattr(diffusers, config['cls'])
    # compare the repo as well as the class: e.g. sd3 and f1 both use AutoencoderKL but different checkpoints
    if (loaded_vae is None) or (cls != loaded_cls) or (repo_id != loaded_repo):
        shared.log.info(f'RePA VAE load: {config["repo_id"]} cls={config["cls"]}')
        loaded_vae = None # drop our reference to the previous VAE before loading the next one to lower peak memory
        # NOTE(review): the dtype is fixed at load time; a later call with latents of another dtype reuses this instance unchanged
        loaded_vae = cls.from_pretrained(
            repo_id,
            torch_dtype=latents.dtype,
            cache_dir=shared.opts.hfcache_dir,
        )
        loaded_cls = cls
        loaded_repo = repo_id
    return loaded_vae