1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/sd_vae_stablecascade.py
Vladimir Mandic c608674fb0 styles support parsed and upparsed save and apply
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-02-11 11:33:37 -05:00

91 lines
3.5 KiB
Python

import os
from torch import nn
import safetensors
from modules import devices, paths
preview_model = None
dtype = devices.dtype_vae
# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
# https://github.com/Stability-AI/StableCascade/blob/master/modules/previewer.py
class Previewer(nn.Module):
def __init__(self, c_in=16, c_hidden=512, c_out=3):
super().__init__()
self.blocks = nn.Sequential(
nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden),
nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
)
def forward(self, x):
return self.blocks(x)
def download_model(model_path):
model_url = 'https://huggingface.co/stabilityai/stable-cascade/resolve/main/previewer.safetensors?download=true'
if not os.path.exists(model_path):
import torch
from installer import log
os.makedirs(os.path.dirname(model_path), exist_ok=True)
log.info(f'Downloading Stable Cascade previewer: {model_path}')
torch.hub.download_url_to_file(model_url, model_path)
def load_model(model_path):
checkpoint = {}
with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
for key in f.keys():
checkpoint[key] = f.get_tensor(key)
return checkpoint
def decode(latents):
from modules import shared
global preview_model # pylint: disable=global-statement
if preview_model is None:
model_path = os.path.join(paths.models_path, "VAE-approx", "sd_cascade_previewer.safetensors")
download_model(model_path)
if os.path.exists(model_path):
preview_model = Previewer()
previewer_checkpoint = load_model(model_path)
preview_model.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
preview_model.eval().requires_grad_(False).to(devices.device, dtype)
del previewer_checkpoint
shared.log.info(f"Load Stable Cascade previewer: model={model_path}")
try:
with devices.inference_context():
latents = latents.detach().clone().unsqueeze(0).to(devices.device, dtype)
image = preview_model(latents)[0].clamp(0, 1).float()
return image
except Exception as e:
shared.log.error(f'Stable Cascade previewer: {e}')
return latents