import os
import sys
import copy
import time
import transformers # pylint: disable=unused-import
import diffusers
from modules import shared, errors, sd_models, sd_checkpoint, model_quant, devices, sd_hijack_te, sd_hijack_vae
from modules.video_models import models_def, video_utils, video_overrides, video_cache


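# the RunAI streamer loader is only selected on Linux; all other platforms
# fall back to the default from_pretrained loader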
def _loader(component):
    """Return loader type for log messages."""
    if sys.platform != 'linux':
        return 'default'
    if component == 'diffusers':
        return 'runai' if shared.opts.runai_streamer_diffusers else 'default'
    return 'runai' if shared.opts.runai_streamer_transformers else 'default'


loaded_model = None # name of the currently loaded video model, used to skip redundant reloads


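# custom pipelines bypass the diffusers pipeline classes entirely;
# currently only Google Veo 3.1 is handled here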
def load_custom(model_name: str):
    shared.log.debug(f'Video load: module=pipe repo="{model_name}" cls=Custom')
    if 'veo-3.1' in model_name:
        from modules.video_models.google_veo import load_veo
        pipe = load_veo(model_name)
        return pipe
    return None


def load_model(selected: models_def.Model):
    """Load a video model and its components into shared.sd_model and return a status message."""
    if selected is None or selected.repo is None:
        return ''
    global loaded_model # pylint: disable=global-statement
    if not shared.sd_loaded:
        loaded_model = None
    if loaded_model == selected.name:
        return ''
    if shared.sd_loaded:
        sd_models.unload_model_weights()

    t0 = time.time()
    jobid = shared.state.begin('Load model')

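    # patch the transformer class for TeaCache before any weights are instantiated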
    video_cache.apply_teacache_patch(selected.dit_cls)

    # overrides
    offline_args = {}
    if shared.opts.offline_mode:
        offline_args["local_files_only"] = True
        os.environ['HF_HUB_OFFLINE'] = '1'
    else:
        os.environ.pop('HF_HUB_OFFLINE', None)
        os.unsetenv('HF_HUB_OFFLINE') # explicitly clear the C-level environment as well

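    # per-model overrides; these may prepopulate pipeline components in kwargs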
    kwargs = video_overrides.load_override(selected, **offline_args)

    # text encoder
    if selected.te_cls is not None:
        try:
            load_args, quant_args = model_quant.get_dit_args({}, module='TE', device_map=True)

            # loader deduplication of text-encoder models
            if selected.te_cls.__name__ == 'T5EncoderModel' and shared.opts.te_shared_t5:
                selected.te = 'Disty0/t5-xxl'
                selected.te_folder = ''
                selected.te_revision = None
            if selected.te_cls.__name__ == 'UMT5EncoderModel' and shared.opts.te_shared_t5:
                if 'SDNQ' in selected.name:
                    selected.te = 'Disty0/Wan2.2-T2V-A14B-SDNQ-uint4-svd-r32'
                else:
                    selected.te = 'Wan-AI/Wan2.2-TI2V-5B-Diffusers'
                selected.te_folder = 'text_encoder'
                selected.te_revision = None
            if selected.te_cls.__name__ == 'LlamaModel' and shared.opts.te_shared_t5:
                selected.te = 'hunyuanvideo-community/HunyuanVideo'
                selected.te_folder = 'text_encoder'
                selected.te_revision = None
            if selected.te_cls.__name__ == 'Qwen2_5_VLForConditionalGeneration' and shared.opts.te_shared_t5:
                selected.te = 'ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers'
                selected.te_folder = 'text_encoder'
                selected.te_revision = None

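            # the text encoder is loaded directly so quantization and loader
            # options apply, then handed to the pipeline through kwargs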
            shared.log.debug(f'Video load: module=te repo="{selected.te or selected.repo}" folder="{selected.te_folder}" cls={selected.te_cls.__name__} quant={model_quant.get_quant_type(quant_args)} loader={_loader("transformers")}')
            kwargs["text_encoder"] = selected.te_cls.from_pretrained(
                pretrained_model_name_or_path=selected.te or selected.repo,
                subfolder=selected.te_folder,
                revision=selected.te_revision or selected.repo_revision,
                cache_dir=shared.opts.hfcache_dir,
                **load_args,
                **quant_args,
                **offline_args,
            )
        except Exception as e:
            shared.log.error(f'Video load: module=te cls={selected.te_cls.__name__} {e}')
            errors.display(e, 'video')

    # transformer
    if selected.dit_cls is not None:
        try:
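            # helper to load a single transformer subfolder; folders already
            # supplied by the per-model overrides in kwargs are skipped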
            def load_dit_folder(dit_folder):
                if dit_folder is not None and dit_folder not in kwargs:
                    # get a new quant arg on every loop to prevent the quant config classes getting entangled
                    load_args, quant_args = model_quant.get_dit_args({}, module='Model', device_map=True)
                    shared.log.debug(f'Video load: module=transformer repo="{selected.dit or selected.repo}" folder="{dit_folder}" cls={selected.dit_cls.__name__} quant={model_quant.get_quant_type(quant_args)} loader={_loader("diffusers")}')
                    kwargs[dit_folder] = selected.dit_cls.from_pretrained(
                        pretrained_model_name_or_path=selected.dit or selected.repo,
                        subfolder=dit_folder,
                        revision=selected.dit_revision or selected.repo_revision,
                        cache_dir=shared.opts.hfcache_dir,
                        **load_args,
                        **quant_args,
                        **offline_args,
                    )
                else:
                    shared.log.debug(f'Video load: module=transformer repo="{selected.dit or selected.repo}" folder="{dit_folder}" cls={selected.dit_cls.__name__} loader={_loader("diffusers")} skip')

            if selected.dit_folder is None:
                selected.dit_folder = ['transformer']
            if isinstance(selected.dit_folder, (list, tuple)):
                for dit_folder in selected.dit_folder: # wan a14b has transformer and transformer_2
                    load_dit_folder(dit_folder)
            else:
                load_dit_folder(selected.dit_folder)
        except Exception as e:
            shared.log.error(f'Video load: module=transformer cls={selected.dit_cls.__name__} {e}')
            errors.display(e, 'video')

    # model
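    # components already loaded into kwargs (text encoder, transformer folders)
    # are injected into the pipeline instead of being fetched again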
    try:
        if selected.repo_cls is None:
            shared.sd_model = load_custom(selected.repo)
        else:
            shared.log.debug(f'Video load: module=pipe repo="{selected.repo}" cls={selected.repo_cls.__name__}')
            shared.sd_model = selected.repo_cls.from_pretrained(
                pretrained_model_name_or_path=selected.repo,
                revision=selected.repo_revision,
                cache_dir=shared.opts.hfcache_dir,
                torch_dtype=devices.dtype,
                **kwargs,
                **offline_args,
            )
    except Exception as e:
        cls_name = selected.repo_cls.__name__ if selected.repo_cls is not None else 'Custom' # avoid AttributeError when a custom pipeline fails
        shared.log.error(f'Video load: module=pipe repo="{selected.repo}" cls={cls_name} {e}')
        errors.display(e, 'video')

    if shared.sd_model is None:
        msg = f'Video load: model="{selected.name}" failed'
        shared.log.error(msg)
        return msg

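    # scheduler and checkpoint bookkeeping; a deep copy of the initial
    # scheduler is stored as default_scheduler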
    t1 = time.time()
    if shared.sd_model.__class__.__name__.startswith("LTX"):
        shared.sd_model.scheduler.config.use_dynamic_shifting = False
    shared.sd_model.default_scheduler = copy.deepcopy(shared.sd_model.scheduler) if hasattr(shared.sd_model, "scheduler") else None
    shared.sd_model.sd_checkpoint_info = sd_checkpoint.CheckpointInfo(selected.repo)
    shared.sd_model.sd_model_hash = None
    sd_models.set_diffuser_options(shared.sd_model, offload=False)

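    # optional hijacks and VAE memory optimizations; each flag feeds the
    # debug summary logged at the end of the load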
    decode, text, image, slicing, tiling, framewise = False, False, False, False, False, False
    if selected.vae_hijack and hasattr(shared.sd_model, 'vae') and hasattr(shared.sd_model.vae, 'decode'):
        sd_hijack_vae.init_hijack(shared.sd_model)
        decode = True
    if selected.te_hijack and hasattr(shared.sd_model, 'encode_prompt'):
        sd_hijack_te.init_hijack(shared.sd_model)
        text = True
    if selected.image_hijack and hasattr(shared.sd_model, 'encode_image'):
        shared.sd_model.orig_encode_image = shared.sd_model.encode_image
        shared.sd_model.encode_image = video_utils.hijack_encode_image
        image = True
    if hasattr(shared.sd_model, 'vae') and hasattr(shared.sd_model.vae, 'use_framewise_decoding'):
        shared.sd_model.vae.use_framewise_decoding = True
        framewise = True
    if hasattr(shared.sd_model, 'vae') and hasattr(shared.sd_model.vae, 'enable_slicing'):
        shared.sd_model.vae.enable_slicing()
        slicing = True
    if hasattr(shared.sd_model, 'vae') and hasattr(shared.sd_model.vae, 'enable_tiling'):
        shared.sd_model.vae.enable_tiling()
        tiling = True
if hasattr(shared.sd_model, "set_progress_bar_config"):
|
|
shared.sd_model.set_progress_bar_config(bar_format='Progress {rate_fmt}{postfix} {bar} {percentage:3.0f}% {n_fmt}/{total_fmt} {elapsed} {remaining} ' + '\x1b[38;5;71m', ncols=80, colour='#327fba')
|
|
|
|
    shared.sd_model = model_quant.do_post_load_quant(shared.sd_model, allow=False)
    sd_models.set_diffuser_offload(shared.sd_model)

    loaded_model = selected.name
    msg = f'Video load: cls={shared.sd_model.__class__.__name__} model="{selected.name}" time={t1-t0:.2f}'
    shared.log.info(msg)
    shared.log.debug(f'Video hijacks: decode={decode} text={text} image={image} slicing={slicing} tiling={tiling} framewise={framewise}')
    shared.state.end(jobid)
    return msg


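# swaps the Wan VAE for spacepxl's 2x-upscaling variant at decode time;
# the original VAE is preserved on the pipeline as orig_vae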
def load_upscale_vae():
    if not hasattr(shared.sd_model, 'vae'):
        return
    if hasattr(shared.sd_model.vae, '_asymmetric_upscale_vae'):
        return # already loaded
    if shared.sd_model.vae.__class__.__name__ != 'AutoencoderKLWan':
        shared.log.warning('Video decode: upscale VAE unsupported')
        return

    repo_id = 'spacepxl/Wan2.1-VAE-upscale2x'
    subfolder = "diffusers/Wan2.1_VAE_upscale2x_imageonly_real_v1"
    vae_decode = diffusers.AutoencoderKLWan.from_pretrained(repo_id, subfolder=subfolder, cache_dir=shared.opts.hfcache_dir)
    vae_decode.requires_grad_(False)
    vae_decode = vae_decode.to(device=devices.device, dtype=devices.dtype)
    vae_decode.eval()
    shared.log.debug(f'Decode: load="{repo_id}"')
    shared.sd_model.orig_vae = shared.sd_model.vae
    shared.sd_model.vae = vae_decode
    shared.sd_model.vae._asymmetric_upscale_vae = True # pylint: disable=protected-access
    sd_hijack_vae.init_hijack(shared.sd_model)
    sd_models.apply_balanced_offload(shared.sd_model, force=True) # reapply offload