#!/usr/bin/env python3
"""
Tiny AutoEncoder for Mochi 1
(DNN for encoding / decoding videos to Mochi 1's latent space)
"""
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm

DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))


def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)


class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3


class MemBlock(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
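

# A minimal sketch of how MemBlock is driven (illustrative, not part of the
# model): each frame's features are fused with the previous frame's features
# at the same resolution, with zeros standing in for the frame before t=0.
#
#   blk = MemBlock(64, 64)
#   x_prev, x_t = torch.randn(2, 1, 64, 32, 32).unbind(0)
#   y_0 = blk(x_prev, x_prev * 0)  # first frame: zero memory
#   y_t = blk(x_t, x_prev)         # later frames: conditioned on the past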


class TPool(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        return self.conv(x.reshape(-1, self.stride * C, H, W))


class TGrow(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        x = self.conv(x)
        return x.reshape(-1, C, H, W)
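

# Shape sketch for the two temporal blocks above (illustrative only), with
# N*T = 6 flattened frames and stride 3:
#
#   TPool(64, 3): (6, 64, H, W) -> reshape -> (2, 3*64, H, W) -> 1x1 conv -> (2, 64, H, W)
#   TGrow(64, 3): (2, 64, H, W) -> 1x1 conv -> (2, 3*64, H, W) -> reshape -> (6, 64, H, W)
#
# i.e. TPool merges each run of `stride` consecutive frames into a single frame,
# and TGrow expands each frame back into `stride` frames.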


def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
    """
    Apply a sequential model with memblocks to the given input.

    Args:
    - model: nn.Sequential of blocks to apply
    - x: input data, of dimensions NTCHW
    - parallel: if True, parallelize over timesteps (fast but uses O(T) memory);
      if False, each timestep will be processed sequentially (slow but uses O(1) memory)
    - show_progress_bar: if True, enables tqdm progressbar display

    Returns NTCHW tensor of output data.
    """
    assert x.ndim == 5, f"TAEM1 operates on NTCHW tensors, but got {x.ndim}-dim tensor"
    N, T, C, H, W = x.shape
    if parallel:
        x = x.reshape(N * T, C, H, W)
        # parallel over input timesteps, iterate over blocks
        for b in tqdm(model, disable=not show_progress_bar):
            if isinstance(b, MemBlock):
                NT, C, H, W = x.shape
                T = NT // N
                _x = x.reshape(N, T, C, H, W)
                # memory for frame t is the block input at frame t-1 (zeros for t=0)
                mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
                x = b(x, mem)
            else:
                x = b(x)
        NT, C, H, W = x.shape
        T = NT // N
        x = x.view(N, T, C, H, W)
    else:
        out = []
        # iterate over input timesteps and also iterate over blocks.
        # because of the cursed TPool/TGrow blocks, this is not a nested loop,
        # it's actually a ***graph traversal*** problem! so let's make a queue
        work_queue = [TWorkItem(xt, 0) for xt in x.reshape(N, T * C, H, W).chunk(T, dim=1)]
        # in addition to manually managing our queue, we also need to manually manage our progressbar.
        # we'll update it for every source node that we consume.
        progress_bar = tqdm(range(T), disable=not show_progress_bar)
        # we'll also need a separate addressable memory per node as well
        mem = [None] * len(model)
        while work_queue:
            xt, i = work_queue.pop(0)
            if i == 0:
                # new source node consumed
                progress_bar.update(1)
            if i == len(model):
                # reached the end of the graph, append the result to the output list
                out.append(xt)
            else:
                # fetch the block to process
                b = model[i]
                if isinstance(b, MemBlock):
                    # mem blocks are simple since we're visiting the graph in causal order
                    if mem[i] is None:
                        xt_new = b(xt, xt * 0)
                        mem[i] = xt
                    else:
                        xt_new = b(xt, mem[i])
                        mem[i].copy_(xt)  # inplace might reduce mysterious pytorch memory allocations? doesn't help though
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt_new, i + 1))
                elif isinstance(b, TPool):
                    # pool blocks are miserable
                    if mem[i] is None:
                        mem[i] = []  # pool memory is itself a queue of inputs to pool
                    mem[i].append(xt)
                    if len(mem[i]) > b.stride:
                        # pool mem is in an invalid state; we should have pooled before this
                        raise ValueError(f"pool memory for block {i} exceeded stride {b.stride}")
                    elif len(mem[i]) < b.stride:
                        # pool mem is not yet full, go back to processing the work queue
                        pass
                    else:
                        # pool mem is ready, run the pool block
                        N, C, H, W = xt.shape
                        xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
                        # reset the pool mem
                        mem[i] = []
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt, i + 1))
                elif isinstance(b, TGrow):
                    xt = b(xt)
                    NT, C, H, W = xt.shape
                    # each TGrow has multiple successor nodes; push them in reverse
                    # so the earliest frame is processed first
                    for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt_next, i + 1))
                else:
                    # normal block with no funny business
                    xt = b(xt)
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt, i + 1))
        progress_bar.close()
        x = torch.stack(out, 1)
    return x
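

# A minimal consistency sketch (assuming CPU or deterministic kernels): the two
# execution modes of apply_model_with_memblocks should agree up to float error.
#
#   model = nn.Sequential(conv(3, 8), MemBlock(8, 8), TPool(8, 2), TGrow(8, 2))
#   x = torch.randn(1, 4, 3, 16, 16)  # NTCHW; T divisible by the pool stride
#   with torch.no_grad():
#       y_par = apply_model_with_memblocks(model, x, parallel=True, show_progress_bar=False)
#       y_seq = apply_model_with_memblocks(model, x, parallel=False, show_progress_bar=False)
#   assert torch.allclose(y_par, y_seq, atol=1e-5)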


class TAEM1(nn.Module):
    latent_channels = 12
    image_channels = 3

    def __init__(self, checkpoint_path="taem1.pth"):
        """Initialize pretrained TAEM1 from the given checkpoint."""
        super().__init__()
        self.encoder = nn.Sequential(
            conv(TAEM1.image_channels, 64), nn.ReLU(inplace=True),
            TPool(64, 3), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            conv(64, TAEM1.latent_channels),
        )
        n_f = [256, 128, 64, 64]
        self.decoder = nn.Sequential(
            Clamp(), conv(TAEM1.latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2), TGrow(n_f[1], 2), conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2), TGrow(n_f[2], 3), conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True), conv(n_f[3], TAEM1.image_channels),
        )
        if checkpoint_path is not None:
            self.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))

    def encode_video(self, x, parallel=True, show_progress_bar=True):
        """Encode a sequence of frames.

        Args:
            x: input NTCHW RGB (C=3) tensor with values in [0, 1].
            parallel: if True, all frames will be processed at once
                (this is faster but may require more memory).
                if False, frames will be processed sequentially.
        Returns NTCHW latent tensor with ~Gaussian values.
        """
        return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)

    def decode_video(self, x, parallel=True, show_progress_bar=True):
        """Decode a sequence of frames.

        Args:
            x: input NTCHW latent (C=12) tensor with ~Gaussian values.
            parallel: if True, all frames will be processed at once
                (this is faster but may require more memory).
                if False, frames will be processed sequentially.
        Returns NTCHW RGB tensor with ~[0, 1] values.
        """
        x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
        # NOTE:
        # the Mochi VAE does not preserve shape along the time axis;
        # videos are encoded to floor((n_in - 1) / 6) + 1 latent frames
        # (which makes sense, it's stride 6, so 12 -> 2 and 13 -> 3)
        # but then they're decoded to only the *minimal* number of input
        # frames (T latents decode to 6 * T - 5 frames, so 3 latents get
        # decoded to 13 frames, not 18) in order to achieve the intended
        # causal structure... anyway, that's why we have to remove some
        # frames here. mochi-VAE does the slicing at each TGrow (to save
        # compute / mem?) but I think it's basically the same.
        return x[:, 5:]

    def forward(self, x):
        # full roundtrip: encode to the latent space, then decode back to RGB
        return self.decode_video(self.encode_video(x))
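

if __name__ == "__main__":
    # Minimal usage sketch: smoke-test the shapes with random weights
    # (checkpoint_path=None skips loading); pass a real "taem1.pth" path
    # to TAEM1() for actual encoding / decoding.
    taem1 = TAEM1(checkpoint_path=None)
    video = torch.rand(1, 12, 3, 128, 128)  # NTCHW RGB in [0, 1]; T divisible by 6
    with torch.no_grad():
        latents = taem1.encode_video(video, show_progress_bar=False)
        print("latents:", tuple(latents.shape))  # (1, 2, 12, 16, 16)
        frames = taem1.decode_video(latents, show_progress_bar=False)
        print("frames:", tuple(frames.shape))  # (1, 7, 3, 128, 128) = 6*2-5 frames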