import os
from torch import nn
import safetensors
from modules import devices, paths


preview_model = None
dtype = devices.dtype_vae


# Fast Decoder for Stage C latents. E.g. 16 x 24 x 24 -> 3 x 192 x 192
# https://github.com/Stability-AI/StableCascade/blob/master/modules/previewer.py
class Previewer(nn.Module):
    def __init__(self, c_in=16, c_hidden=512, c_out=3):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),

            nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 2),

            nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2),  # 64 -> 128
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden // 4),

            nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
        )

    def forward(self, x):
        return self.blocks(x)
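
# Shape sketch (illustrative comment only; assumes torch is imported and the default
# constructor arguments above): the three stride-2 ConvTranspose2d stages double the
# spatial size three times, so a 16 x 24 x 24 Stage C latent decodes to a 3 x 192 x 192 preview:
#   x = torch.randn(1, 16, 24, 24)
#   assert Previewer()(x).shape == (1, 3, 192, 192)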


def download_model(model_path):
    model_url = 'https://huggingface.co/stabilityai/stable-cascade/resolve/main/previewer.safetensors?download=true'
    if not os.path.exists(model_path):
        import torch
        from installer import log
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        log.info(f'Downloading Stable Cascade previewer: {model_path}')
        torch.hub.download_url_to_file(model_url, model_path)


def load_model(model_path):
    checkpoint = {}
    with safetensors.safe_open(model_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            checkpoint[key] = f.get_tensor(key)
    return checkpoint
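
# Descriptive note: load_model() returns the raw tensor dict; decode() below passes it to
# load_state_dict() and accepts either a flat dict or one nested under a 'state_dict' key,
# so both checkpoint layouts load correctly.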


def decode(latents):
    from modules import shared
    global preview_model # pylint: disable=global-statement
    if preview_model is None:
        model_path = os.path.join(paths.models_path, "VAE-approx", "sd_cascade_previewer.safetensors")
        download_model(model_path)
        if os.path.exists(model_path):
            preview_model = Previewer()
            previewer_checkpoint = load_model(model_path)
            preview_model.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
            preview_model.eval().requires_grad_(False).to(devices.device, dtype)
            del previewer_checkpoint
            shared.log.info(f"Load Stable Cascade previewer: model={model_path}")
    try:
        with devices.inference_context():
            latents = latents.detach().clone().unsqueeze(0).to(devices.device, dtype)
            image = preview_model(latents)[0].clamp(0, 1).float()
            return image
    except Exception as e:
        shared.log.error(f'Stable Cascade previewer: {e}')
        return latents
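

# Minimal usage sketch, assuming this module runs inside the sdnext environment so that
# modules.devices, modules.paths and modules.shared resolve; the random tensor below is only
# a stand-in for a real Stage C latent produced by the sampler.
if __name__ == '__main__':
    import torch
    preview = decode(torch.randn(16, 24, 24))  # decode() adds the batch dim and moves to device/dtype
    print(preview.shape)  # expected: torch.Size([3, 192, 192]) with the default Previewer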