# Mirror of https://github.com/vladmandic/sdnext.git (synced 2026-01-27 15:02:48 +03:00)
# File: sdnext/modules/taesd/taem1.py
# Last commit: cc0b0e8e3d "cleanup todo" by vladmandic (2026-01-19 11:10:05 +01:00)
# Signed-off-by: vladmandic <mandic00@live.com>
# 207 lines, 9.4 KiB, Python

#!/usr/bin/env python3
"""
Tiny AutoEncoder for Mochi 1
(DNN for encoding / decoding videos to Mochi 1's latent space)
"""
from collections import deque, namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
# Record with `frame` and `memory` fields (not referenced by the code below).
DecoderResult = namedtuple("DecoderResult", "frame memory")
# Work-queue entry for the sequential traversal in apply_model_with_memblocks:
# a tensor plus the index of the model block it must be fed to next.
TWorkItem = namedtuple("TWorkItem", "input_tensor block_index")
def conv(n_in, n_out, **kwargs):
    """Build a 3x3 ``nn.Conv2d`` with padding 1 (spatial size preserved at stride 1).

    Any extra keyword arguments (e.g. ``stride``, ``bias``) are forwarded to ``nn.Conv2d``.
    """
    return nn.Conv2d(n_in, n_out, kernel_size=3, padding=1, **kwargs)
class Clamp(nn.Module):
    """Smoothly squash values so they stay bounded by +/-3.

    Computes 3 * tanh(x / 3): close to the identity near zero, while large magnitudes
    saturate toward +/-3 instead of being hard-clipped.
    """

    def forward(self, x):
        return x.div(3).tanh().mul(3)
class MemBlock(nn.Module):
    """Residual conv block that mixes current features with a memory tensor.

    forward(x, past) returns relu(body(cat([x, past], dim=1)) + skip(x)), where ``body``
    is three 3x3 convs with ReLUs in between and ``skip`` is a bias-free 1x1 conv when
    the channel count changes (identity otherwise). ``past`` is expected to match
    ``x``'s shape, since the first conv takes ``2 * n_in`` channels.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        # attribute names (conv / skip / act) and layer order define checkpoint keys
        layers = [
            conv(n_in * 2, n_out), nn.ReLU(inplace=True),
            conv(n_out, n_out), nn.ReLU(inplace=True),
            conv(n_out, n_out),
        ]
        self.conv = nn.Sequential(*layers)
        if n_in == n_out:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(n_in, n_out, 1, bias=False)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        stacked = torch.cat((x, past), dim=1)
        residual = self.conv(stacked) + self.skip(x)
        return self.act(residual)
class TPool(nn.Module):
    """Temporal pooling: merge every ``stride`` consecutive frames into one.

    Frames live on the leading (batch * time) axis. Each group of ``stride`` consecutive
    frames is concatenated along channels, then a bias-free 1x1 conv maps
    ``n_f * stride`` channels back to ``n_f``: (B * stride, n_f, H, W) -> (B, n_f, H, W).
    """

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        in_channels = n_f * stride
        self.conv = nn.Conv2d(in_channels, n_f, 1, bias=False)

    def forward(self, x):
        channels, height, width = x.shape[1:]
        grouped = x.reshape(-1, channels * self.stride, height, width)
        return self.conv(grouped)
class TGrow(nn.Module):
    """Temporal upsampling: expand every frame into ``stride`` consecutive frames.

    A bias-free 1x1 conv multiplies the channel count by ``stride``; the result is
    reshaped so each channel group becomes its own frame on the leading
    (batch * time) axis: (B, n_f, H, W) -> (B * stride, n_f, H, W).
    """

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        out_channels = n_f * stride
        self.conv = nn.Conv2d(n_f, out_channels, 1, bias=False)

    def forward(self, x):
        channels, height, width = x.shape[1:]
        expanded = self.conv(x)
        return expanded.reshape(-1, channels, height, width)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
    """
    Apply a sequential model with memblocks to the given input.

    Args:
    - model: nn.Sequential of blocks to apply
    - x: input data, of dimensions NTCHW
    - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
                if False, each timestep will be processed sequentially (slow but uses O(1) memory)
    - show_progress_bar: if True, enables tqdm progressbar display

    Returns NTCHW tensor of output data. The output T can differ from the input T:
    each TPool divides the frame count by its stride and each TGrow multiplies it.

    Each MemBlock receives, as its memory, the tensor that was fed to that same block
    at the previous timestep (all zeros for the first timestep). In sequential mode,
    frames left in a TPool that never fills up (frame count not a multiple of the
    pooling strides) are dropped.
    """
    assert x.ndim == 5, f"TAEM1 operates on NTCHW tensors, but got {x.ndim}-dim tensor"
    N, T, C, H, W = x.shape

    if parallel:
        x = x.reshape(N * T, C, H, W)
        # parallel over input timesteps, iterate over blocks
        for b in tqdm(model, disable=not show_progress_bar):
            if isinstance(b, MemBlock):
                NT, C, H, W = x.shape
                T = NT // N
                _x = x.reshape(N, T, C, H, W)
                # memory for timestep t is the input at t-1: F.pad pairs run from the last dim
                # backwards, so (1, 0) on dim 1 prepends one zero frame; [:, :T] drops the last one
                mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
                x = b(x, mem)
            else:
                x = b(x)
        NT, C, H, W = x.shape
        T = NT // N
        return x.view(N, T, C, H, W)

    out = []
    # Iterate over input timesteps and over blocks. Because TPool consumes several frames
    # per output and TGrow emits several frames per input, this is not a nested loop but a
    # depth-first graph traversal, driven by a deque of (tensor, next-block-index) items
    # whose front is processed first.
    work_queue = deque(TWorkItem(xt, 0) for xt in x.reshape(N, T * C, H, W).chunk(T, dim=1))
    # per-block state: previous input for MemBlocks, pending frames for TPools
    mem = [None] * len(model)
    # ticks once per consumed source frame; `with` closes the bar even if a block raises
    with tqdm(range(T), disable=not show_progress_bar) as progress_bar:
        while work_queue:
            xt, i = work_queue.popleft()
            if i == 0:
                # new source node consumed
                progress_bar.update(1)
            if i == len(model):
                # reached end of the graph: this is a finished output frame
                out.append(xt)
                continue
            b = model[i]
            if isinstance(b, MemBlock):
                # frames reach each block in causal order, so mem[i] holds the previous
                # timestep's input; zeros (not xt * 0, which keeps NaN/inf) match the
                # zero padding of the parallel path
                past = torch.zeros_like(xt) if mem[i] is None else mem[i]
                xt_new = b(xt, past)
                # Rebind instead of copy_-ing into the old tensor: the stored tensor may be a
                # view into other data (e.g. a frame of the caller's `x`), which copy_ would overwrite.
                mem[i] = xt
                work_queue.appendleft(TWorkItem(xt_new, i + 1))
            elif isinstance(b, TPool):
                # collect `stride` consecutive frames before running the pool block
                if mem[i] is None:
                    mem[i] = []
                mem[i].append(xt)
                if len(mem[i]) > b.stride:
                    raise ValueError(f"TPool at block {i} holds {len(mem[i])} frames but its stride is {b.stride}")
                if len(mem[i]) == b.stride:
                    n, c, h, w = xt.shape
                    # lay the frames out as (n * stride, c, h, w), the layout TPool reshapes
                    pooled = b(torch.cat(mem[i], 1).view(n * b.stride, c, h, w))
                    mem[i] = []
                    work_queue.appendleft(TWorkItem(pooled, i + 1))
            elif isinstance(b, TGrow):
                # each TGrow has `stride` successor frames
                grown = b(xt)
                _, c, h, w = grown.shape
                frames = grown.view(N, b.stride * c, h, w).chunk(b.stride, 1)
                # extendleft reverses its input, so feed frames reversed to keep frame 0 at the front
                work_queue.extendleft(TWorkItem(frame, i + 1) for frame in reversed(frames))
            else:
                # normal block with no temporal state
                work_queue.appendleft(TWorkItem(b(xt), i + 1))
    return torch.stack(out, 1)
class TAEM1(nn.Module):
    """Tiny autoencoder between RGB videos and Mochi 1's 12-channel latent space.

    Encoder: three stride-2 convs (8x spatial downsampling) and TPool blocks with
    strides 3, 2, 1 (6x temporal downsampling). Decoder: three 2x upsamples and
    TGrow blocks with strides 1, 2, 3 (6x temporal upsampling).
    """

    latent_channels = 12
    image_channels = 3

    def __init__(self, checkpoint_path="taem1.pth"):
        """Initialize pretrained TAEM1 from the given checkpoints.

        Args:
            checkpoint_path: path of a state dict to load (with ``weights_only=True``);
                pass None to keep randomly initialized weights. The attribute names and
                layer order below define the state-dict keys, so they must stay as-is.
        """
        super().__init__()
        self.encoder = nn.Sequential(
            conv(TAEM1.image_channels, 64), nn.ReLU(inplace=True),
            TPool(64, 3), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            conv(64, TAEM1.latent_channels),
        )
        n_f = [256, 128, 64, 64]
        self.decoder = nn.Sequential(
            Clamp(), conv(TAEM1.latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2), TGrow(n_f[1], 2), conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2), TGrow(n_f[2], 3), conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True), conv(n_f[3], TAEM1.image_channels),
        )
        if checkpoint_path is not None:
            self.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))

    def encode_video(self, x, parallel=True, show_progress_bar=True):
        """Encode a sequence of frames.

        Args:
            x: input NTCHW RGB (C=3) tensor with values in [0, 1].
            parallel: if True, all frames will be processed at once.
                (this is faster but may require more memory).
                if False, frames will be processed sequentially.
            show_progress_bar: if True, display a tqdm progress bar.

        Returns NTCHW latent tensor with ~Gaussian values.

        NOTE(review): each TPool groups frames by its stride (3, 2, 1), so T presumably
        needs to be a multiple of 6; this is not validated here.
        """
        return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)

    def decode_video(self, x, parallel=True, show_progress_bar=True):
        """Decode a sequence of frames.

        Args:
            x: input NTCHW latent (C=12) tensor with ~Gaussian values.
            parallel: if True, all frames will be processed at once.
                (this is faster but may require more memory).
                if False, frames will be processed sequentially.
            show_progress_bar: if True, display a tqdm progress bar.

        Returns NTCHW RGB tensor with ~[0, 1] values; T_out = 6 * T_in - 5.
        """
        x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
        # NOTE:
        # the Mochi VAE does not preserve shape along the time axis;
        # videos are encoded to floor((n_in - 1)/6)+1 latent frames
        # (which makes sense, it's stride 6, so 12 -> 2 and 13->3)
        # but then they're decoded to only the *minimal* number
        # of input frames (3 latents get decoded to 13 frames, not 18)
        # in order to achieve the intended causal structure...
        # anyway, that's why we have to remove some frames here.
        # mochi-VAE does the slicing at each TGrow (save compute/mem?)
        # but I think it's basically the same
        # 5 = 1 * 2 * 3 (product of the decoder's TGrow strides) - 1
        return x[:, 5:]

    def forward(self, x):
        """TAEM1 has no single forward pass; use encode_video() or decode_video().

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("TAEM1 has no forward pass; call encode_video() or decode_video() instead")