mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
33 lines
1.1 KiB
Python
33 lines
1.1 KiB
Python
import diffusers
|
|
from modules import shared
|
|
|
|
|
|
models = {
|
|
'sd': { 'repo_id': 'REPA-E/e2e-sdvae-hf', 'cls': 'AutoencoderKL' },
|
|
'sdxl': { 'repo_id': 'REPA-E/e2e-sdvae-hf', 'cls': 'AutoencoderKL' },
|
|
'sd3': { 'repo_id': 'REPA-E/e2e-sd3.5-vae', 'cls': 'AutoencoderKL' },
|
|
'f1': { 'repo_id': 'REPA-E/e2e-flux-vae', 'cls': 'AutoencoderKL' },
|
|
'qwen': { 'repo_id': 'REPA-E/e2e-qwenimage-vae', 'cls': 'AutoencoderKLQwenImage' },
|
|
}
|
|
loaded_cls = None
|
|
loaded_vae = None
|
|
|
|
|
|
def repa_load(latents):
|
|
global loaded_cls, loaded_vae # pylint: disable=global-statement
|
|
config = models.get(shared.sd_model_type, None)
|
|
if config is None:
|
|
shared.log.error(f'Decode: type="repa" model={shared.sd_model_type} not supported')
|
|
return latents
|
|
|
|
cls = getattr(diffusers, config['cls'])
|
|
if (cls != loaded_cls) or (loaded_vae is None):
|
|
shared.log.info(f'RePA VAE load: {config["repo_id"]} cls={config["cls"]}')
|
|
loaded_vae = cls.from_pretrained(
|
|
config['repo_id'],
|
|
torch_dtype=latents.dtype,
|
|
cache_dir=shared.opts.hfcache_dir,
|
|
)
|
|
loaded_cls = cls
|
|
return loaded_vae
|