1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/postprocess/seedvr_model.py
Vladimir Mandic 85a58ed5bf seedvr enable quant
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-10-14 10:23:12 -04:00

172 lines
7.4 KiB
Python

import time
import random
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import ToPILImage
from modules import devices
from modules.shared import opts, log
from modules.upscaler import Upscaler, UpscalerData
# Lookup from the upscaler name selected by the user to the SeedVR2 checkpoint
# filename handed to the runner configuration.
MODELS_MAP = {
    "SeedVR2 3B": "seedvr2_ema_3b_fp16.safetensors",
    "SeedVR2 7B": "seedvr2_ema_7b_fp16.safetensors",
    "SeedVR2 7B Sharp": "seedvr2_ema_7b_sharp_fp16.safetensors",
}
# Module-level torchvision transform, created once and reused to turn the
# upscaled result tensor back into a PIL image.
to_pil = ToPILImage()
class UpscalerSeedVR(Upscaler):
    """Upscaler backend that drives a SeedVR2 runner.

    The runner's VAE encode/decode hooks and the generation step are replaced
    with the methods of this class, which move the DiT and the VAE between the
    CPU and ``self.device`` so that only the sub-model currently in use sits on
    the compute device.
    """

    # Unpatched ``generation.generation_step``. Kept on the class (and always
    # read through the class, so it is never bound to an instance) so that a
    # reload wraps the real step function instead of our own wrapper.
    _generation_step = None

    def __init__(self, dirname=None):
        """Register the available SeedVR2 variants; no model is loaded yet.

        Args:
            dirname: Accepted for signature compatibility; not used here.
        """
        self.name = "SeedVR2"
        super().__init__()
        self.scalers = [
            UpscalerData(name="SeedVR2 3B", path=None, upscaler=self, model=None, scale=1),
            UpscalerData(name="SeedVR2 7B", path=None, upscaler=self, model=None, scale=1),
            UpscalerData(name="SeedVR2 7B Sharp", path=None, upscaler=self, model=None, scale=1),
        ]
        self.model = None
        self.model_loaded = None

    def load_model(self, path: str):
        """Load the runner for the selected variant unless it is already loaded.

        Args:
            path: Upscaler name as listed in ``MODELS_MAP`` (not a filesystem path).
        """
        model_name = MODELS_MAP.get(path, None)
        if (self.model is None) or (self.model_loaded != model_name):
            log.debug(f'Upscaler loading: name="{self.name}" model="{model_name}"')
            t0 = time.time()
            from modules.seedvr.src.core.model_manager import configure_runner
            from modules.seedvr.src.core import generation
            self.model = configure_runner(
                model_name=model_name,
                cache_dir=opts.hfcache_dir,
                device=devices.device,
                dtype=devices.dtype,
            )
            self.model_loaded = model_name
            self.model.dit.device = devices.device
            self.model.dit.dtype = devices.dtype
            self.model.vae_encode = self.vae_encode
            self.model.vae_decode = self.vae_decode
            # The module attribute stays patched after the first load. Capturing
            # it again on a reload (after unload or a model switch) would store
            # our own wrapper as the "real" step and make model_step() recurse
            # forever, so only capture it when it is not already our wrapper.
            current_step = generation.generation_step
            if getattr(current_step, "__func__", None) is not UpscalerSeedVR.model_step:
                UpscalerSeedVR._generation_step = current_step
            self.model.model_step = UpscalerSeedVR._generation_step
            generation.generation_step = self.model_step
            self.model._internal_dict = {
                'dit': self.model.dit,
                'vae': self.model.vae,
            }
            # NOTE: t1 is taken before post-load quantization, so the logged
            # load time does not include it.
            t1 = time.time()
            self.model.dit.config = self.model.config.dit
            self.model.vae.tile_sample_min_size = 1024
            self.model.vae.tile_latent_min_size = 128
            from modules.model_quant import do_post_load_quant
            self.model = do_post_load_quant(self.model, allow=True)
            # from modules.sd_offload import set_diffuser_offload
            # set_diffuser_offload(self.model)
            log.info(f'Upscaler loaded: name="{self.name}" model="{model_name}" time={t1 - t0:.2f}')

    def vae_encode(self, samples):
        """Encode samples into scaled latents with the VAE on ``self.device``.

        The DiT is moved to the CPU first and the VAE is moved back to the CPU
        afterwards.

        Args:
            samples: Sequence of sample tensors; each is encoded on its own
                with a leading batch dimension of one.

        Returns:
            List of latents (channels last, batch dimension removed); an empty
            list when ``samples`` is empty.
        """
        log.debug(f'Upscaler encode: samples={samples[0].shape if len(samples) > 0 else None} tile={self.model.vae.tile_sample_min_size} overlap={self.model.vae.tile_overlap_factor}')
        latents = []
        if len(samples) == 0:
            return latents
        self.model.dit = self.model.dit.to(device="cpu")
        self.model.vae = self.model.vae.to(device=self.device)
        devices.torch_gc()
        from einops import rearrange
        from modules.seedvr.src.optimization import memory_manager
        memory_manager.clear_rope_cache(self.model)
        scale = self.model.config.vae.scaling_factor
        shift = self.model.config.vae.get("shifting_factor", 0.0)
        batches = [sample.unsqueeze(0) for sample in samples]
        for sample in batches:
            sample = sample.to(self.device, self.model.vae.dtype)
            sample = self.model.vae.preprocess(sample)
            latent = self.model.vae.encode(sample).latent
            # 4-D latents get a singleton axis at position 2 so every latent is 5-D
            latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
            latent = rearrange(latent, "b c ... -> b ... c")
            latent = (latent - shift) * scale
            latents.append(latent)
        latents = [latent.squeeze(0) for latent in latents]
        self.model.vae = self.model.vae.to(device="cpu")
        devices.torch_gc()
        return latents

    def vae_decode(self, latents, target_dtype: torch.dtype = None):
        """Decode latents back into samples with the VAE on ``self.device``.

        Inverse of ``vae_encode``: undoes the scale/shift, restores the
        channels-first layout and decodes each latent on its own.

        Args:
            latents: Sequence of channels-last latent tensors.
            target_dtype: Accepted for signature compatibility; not used here.

        Returns:
            List of decoded samples (batch dimension removed); an empty list
            when ``latents`` is empty.
        """
        log.debug(f'Upscaler decode: latents={latents[0].shape if len(latents) > 0 else None} tile={self.model.vae.tile_latent_min_size} overlap={self.model.vae.tile_overlap_factor}')
        samples = []
        if len(latents) == 0:
            return samples
        from einops import rearrange
        from modules.seedvr.src.optimization import memory_manager
        memory_manager.clear_rope_cache(self.model)
        self.model.dit = self.model.dit.to(device="cpu")
        self.model.vae = self.model.vae.to(device=self.device)
        devices.torch_gc()
        scale = self.model.config.vae.scaling_factor
        shift = self.model.config.vae.get("shifting_factor", 0.0)
        latents = [latent.unsqueeze(0) for latent in latents]
        with devices.inference_context():
            for _i, latent in enumerate(latents):
                latent = latent.to(self.device, self.model.vae.dtype)
                latent = latent / scale + shift
                latent = rearrange(latent, "b ... c -> b c ...")
                # drops the axis at position 2 when it has size one
                latent = latent.squeeze(2)
                sample = self.model.vae.decode(latent).sample
                sample = self.model.vae.postprocess(sample)
                samples.append(sample)
        samples = [sample.squeeze(0) for sample in samples]
        self.model.vae = self.model.vae.to(device="cpu")
        devices.torch_gc()
        return samples

    def model_step(self, *args, **kwargs):
        """Run the wrapped generation step with the DiT on ``self.device``.

        Installed as ``generation.generation_step`` by ``load_model``; forwards
        all arguments to the step function stored on the runner and moves the
        DiT back to the CPU afterwards.

        Returns:
            Whatever the wrapped generation step returns.
        """
        from modules.seedvr.src.optimization import memory_manager
        self.model.vae = self.model.vae.to(device="cpu")
        self.model.dit = self.model.dit.to(device=self.device)
        devices.torch_gc()
        log.debug(f'Upscaler inference: args={len(args)} kwargs={list(kwargs.keys())}')
        memory_manager.preinitialize_rope_cache(self.model)
        with devices.inference_context():
            result = self.model.model_step(*args, **kwargs)
        self.model.dit = self.model.dit.to(device="cpu")
        devices.torch_gc()
        return result

    def do_upscale(self, img: Image.Image, selected_file):
        """Upscale a single image with the selected SeedVR2 variant.

        Args:
            img: Input image.
            selected_file: Upscaler name as listed in ``MODELS_MAP``.

        Returns:
            The upscaled image, or ``img`` unchanged when no model is loaded.
        """
        self.load_model(selected_file)
        if self.model is None:
            return img
        from modules.seedvr.src.core import generation
        # target width is rounded down to a multiple of 8
        width = int(self.scale * img.width) // 8 * 8
        # The result is handled below as a 3-D H x W x C tensor, so greyscale,
        # palette and alpha inputs are normalized to RGB first.
        # NOTE(review): assumes the SeedVR2 VAE takes 3-channel input - confirm.
        if img.mode != "RGB":
            img = img.convert("RGB")
        image_tensor = np.array(img)
        image_tensor = torch.from_numpy(image_tensor).to(device=devices.device, dtype=devices.dtype).unsqueeze(0) / 255.0
        # reseed the global RNG so a fresh seed is drawn on every call
        random.seed()
        seed = int(random.randrange(4294967294))
        t0 = time.time()
        with devices.inference_context():
            result_tensor = generation.generation_loop(
                runner=self.model,
                images=image_tensor,
                cfg_scale=opts.seedvt_cfg_scale,
                seed=seed,
                res_w=width,
                batch_size=1,
                temporal_overlap=0,
                device=devices.device,
            )
        t1 = time.time()
        log.info(f'Upscaler: type="{self.name}" model="{selected_file}" scale={self.scale} cfg={opts.seedvt_cfg_scale} seed={seed} time={t1 - t0:.2f}')
        # channels-last result -> channels-first, as expected by ToPILImage for tensors
        img = to_pil(result_tensor.squeeze().permute((2, 0, 1)))
        if opts.upscaler_unload:
            # drop every reference to the sub-models so their memory can be reclaimed
            self.model.dit = None
            self.model.vae = None
            self.model.cache = None
            self.model = None
            log.debug(f'Upscaler unload: type="{self.name}" model="{selected_file}"')
            devices.torch_gc(force=True)
        return img