1
0
mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-26 23:41:35 +03:00

support tiny vae

This commit is contained in:
kijai
2025-03-15 00:48:55 +02:00
parent 84a26d30f9
commit ce79176646
4 changed files with 537 additions and 12 deletions

189
latent_preview.py Normal file
View File

@@ -0,0 +1,189 @@
import torch
from PIL import Image
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.model_management
import folder_paths
import comfy.utils
import logging
import os
from .taehv import TAEHV
MAX_PREVIEW_RESOLUTION = args.preview_size
def preview_to_image(latent_image):
print("latent_image shape: ", latent_image.shape)#torch.Size([60, 104, 3])
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
)
if comfy.model_management.directml_enabled:
latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
return Image.fromarray(latents_ubyte.numpy())
class LatentPreviewer:
def decode_latent_to_preview(self, x0):
pass
def decode_latent_to_preview_image(self, preview_format, x0):
preview_image = self.decode_latent_to_preview(x0)
return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
class TAESDPreviewerImpl(LatentPreviewer):
def __init__(self, taesd):
self.taesd = taesd
# def decode_latent_to_preview(self, x0):
# #x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
# print("x0 shape: ", x0.shape) #torch.Size([5, 16, 60, 104])
# x0 = x0.unsqueeze(0)
# print("x0 shape: ", x0.shape) #torch.Size([5, 16, 60, 104])
# x_sample = self.taesd.decode_video(x0, parallel=False)[0].permute(0, 2, 3, 1)[0]
# print("x_sample shape: ", x_sample.shape)
# return preview_to_image(x_sample)
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None
if latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
def decode_latent_to_preview(self, x0):
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
if x0.ndim == 5:
x0 = x0[0, :, 0]
else:
x0 = x0[0]
latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
# latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
return preview_to_image(latent_image)
def get_previewer(device, latent_format):
previewer = None
method = args.preview_method
if method != LatentPreviewMethod.NoPreviews:
# TODO previewer methods
taesd_decoder_path = None
if method == LatentPreviewMethod.Auto:
method = LatentPreviewMethod.Latent2RGB
if method == LatentPreviewMethod.TAESD:
taehv_path = os.path.join(folder_paths.models_dir, "vae_approx", "taew2_1.safetensors")
if not os.path.exists(taehv_path):
raise RuntimeError(f"Could not find {taehv_path}")
taew_sd = comfy.utils.load_torch_file(taehv_path)
taesd = TAEHV(taew_sd).to(device)
previewer = TAESDPreviewerImpl(taesd)
previewer = WrappedPreviewer(previewer)
if previewer is None:
if latent_format.latent_rgb_factors is not None:
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
return previewer
def prepare_callback(model, steps, x0_output_dict=None):
preview_format = "JPEG"
if preview_format not in ["JPEG", "PNG"]:
preview_format = "JPEG"
previewer = get_previewer(model.load_device, model.model.latent_format)
print("previewer: ", previewer)
pbar = comfy.utils.ProgressBar(steps)
def callback(step, x0, x, total_steps):
if x0_output_dict is not None:
x0_output_dict["x0"] = x0
preview_bytes = None
if previewer:
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback
#borrowed VideoHelperSuite https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/blob/main/videohelpersuite/latent_preview.py
import server
from threading import Thread
import torch.nn.functional as F
import io
import time
serv = server.PromptServer.instance
class WrappedPreviewer(LatentPreviewer):
def __init__(self, previewer, rate=16):
self.first_preview = True
self.last_time = 0
self.c_index = 0
self.rate = rate
if hasattr(previewer, 'taesd'):
self.taesd = previewer.taesd
elif hasattr(previewer, 'latent_rgb_factors'):
self.latent_rgb_factors = previewer.latent_rgb_factors
self.latent_rgb_factors_bias = previewer.latent_rgb_factors_bias
else:
raise Exception('Unsupported preview type for VHS animated previews')
def decode_latent_to_preview_image(self, preview_format, x0):
if x0.ndim == 5:
#Keep batch major
x0 = x0.movedim(2,1)
x0 = x0.reshape((-1,)+x0.shape[-3:])
num_images = x0.size(0)
new_time = time.time()
num_previews = int((new_time - self.last_time) * self.rate)
self.last_time = self.last_time + num_previews/self.rate
if num_previews > num_images:
num_previews = num_images
elif num_previews <= 0:
return None
if self.first_preview:
self.first_preview = False
serv.send_sync('VHS_latentpreview', {'length':num_images, 'rate': self.rate})
self.last_time = new_time + 1/self.rate
if self.c_index + num_previews > num_images:
x0 = x0.roll(-self.c_index, 0)[:num_previews]
else:
x0 = x0[self.c_index:self.c_index + num_previews]
Thread(target=self.process_previews, args=(x0, self.c_index,
num_images)).run()
self.c_index = (self.c_index + num_previews) % num_images
return None
def process_previews(self, image_tensor, ind, leng):
max_size = 256
image_tensor = self.decode_latent_to_preview(image_tensor)
if image_tensor.size(1) > max_size or image_tensor.size(2) > max_size:
image_tensor = image_tensor.movedim(-1,0)
if image_tensor.size(2) < image_tensor.size(3):
height = (max_size * image_tensor.size(2)) // image_tensor.size(3)
image_tensor = F.interpolate(image_tensor, (height,max_size), mode='bilinear')
else:
width = (max_size * image_tensor.size(3)) // image_tensor.size(2)
image_tensor = F.interpolate(image_tensor, (max_size, width), mode='bilinear')
image_tensor = image_tensor.movedim(0,-1)
previews_ubyte = (image_tensor.clamp(0, 1)
.mul(0xFF) # to 0..255
).to(device="cpu", dtype=torch.uint8)
for preview in previews_ubyte:
i = Image.fromarray(preview.numpy())
message = io.BytesIO()
message.write((1).to_bytes(length=4, byteorder='big')*2)
message.write(ind.to_bytes(length=4, byteorder='big'))
i.save(message, format="JPEG", quality=95, compress_level=1)
#NOTE: send sync already uses call_soon_threadsafe
serv.send_sync(server.BinaryEventTypes.PREVIEW_IMAGE,
message.getvalue(), serv.client_id)
ind = (ind + 1) % leng
def decode_latent_to_preview(self, x0):
x0 = x0.unsqueeze(0)
x_sample = self.taesd.decode_video(x0, parallel=False)[0].permute(0, 2, 3, 1)
return x_sample

View File

@@ -15,6 +15,7 @@ from .wanvideo.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight, set_num_frames
from .taehv import TAEHV
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
@@ -678,6 +679,42 @@ class WanVideoVAELoader:
return (vae,)
class WanVideoTinyVAELoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (folder_paths.get_filename_list("vae_approx"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae_approx'"}),
},
"optional": {
"precision": (["fp16", "fp32", "bf16"],
{"default": "fp16"}
),
}
}
RETURN_TYPES = ("WANVAE",)
RETURN_NAMES = ("vae", )
FUNCTION = "loadmodel"
CATEGORY = "WanVideoWrapper"
DESCRIPTION = "Loads Hunyuan VAE model from 'ComfyUI/models/vae'"
def loadmodel(self, model_name, precision):
from .taehv import TAEHV
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
model_path = folder_paths.get_full_path("vae_approx", model_name)
vae_sd = load_torch_file(model_path, safe_load=True)
vae = TAEHV(vae_sd)
vae.to(device = offload_device, dtype = dtype)
return (vae,)
class WanVideoTorchCompileSettings:
@@ -1403,7 +1440,7 @@ class WanVideoSampler:
pbar = ProgressBar(steps)
from latent_preview import prepare_callback
from .latent_preview import prepare_callback
callback = prepare_callback(patcher, steps)
#blockswap init
@@ -1778,9 +1815,13 @@ class WanVideoSampler:
to_decode = self.previous_noise_pred_context[:,-1,:, :].unsqueeze(1).unsqueeze(0).to(context_vae.dtype)
#to_decode = to_decode.permute(0, 1, 3, 2)
#print("to_decode.shape", to_decode.shape)
image = context_vae.decode(to_decode, device=device, tiled=False)[0]
if isinstance(context_vae, TAEHV):
image = context_vae.decode_video(to_decode.permute(0, 2, 1, 3, 4), parallel=False)[0].permute(1, 0, 2, 3)
image = context_vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False).permute(0, 2, 1, 3, 4)
else:
image = context_vae.decode(to_decode, device=device, tiled=False)[0]
image = context_vae.encode(image.unsqueeze(0).to(context_vae.dtype), device=device, tiled=False)
#print("decoded image.shape", image.shape) #torch.Size([3, 37, 832, 480])
image = context_vae.encode(image.unsqueeze(0).to(context_vae.dtype), device=device, tiled=False)
#print("encoded image.shape", image.shape)
#partial_img_emb[:, 0, :, :] = image[0][:,0,:,:]
#print("partial_img_emb.shape", partial_img_emb.shape)
@@ -1934,14 +1975,17 @@ class WanVideoDecode:
mm.soft_empty_cache()
image = vae.decode(latents, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))[0]
print(image.shape)
print(image.min(), image.max())
if isinstance(vae, TAEHV):
image = vae.decode_video(latents.permute(0, 2, 1, 3, 4))[0].permute(1, 0, 2, 3)
else:
image = vae.decode(latents, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))[0]
vae.model.clear_cache()
image = (image - image.min()) / (image.max() - image.min())
vae.to(offload_device)
vae.model.clear_cache()
mm.soft_empty_cache()
image = (image - image.min()) / (image.max() - image.min())
image = torch.clamp(image, 0.0, 1.0)
image = image.permute(1, 2, 3, 0).cpu().float()
@@ -1978,16 +2022,21 @@ class WanVideoEncode:
vae.to(device)
image = (image.clone() * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
image = (image.clone()).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
if noise_aug_strength > 0.0:
image = add_noise_to_reference_video(image, ratio=noise_aug_strength)
latents = vae.encode(image, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))
if isinstance(vae, TAEHV):
latents = vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False)# B, T, C, H, W
latents = latents.permute(0, 2, 1, 3, 4)
else:
latents = vae.encode(image * 2.0 - 1.0, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))
vae.model.clear_cache()
if latent_strength != 1.0:
latents *= latent_strength
vae.to(offload_device)
vae.model.clear_cache()
mm.soft_empty_cache()
print("encoded latents shape",latents.shape)
@@ -2121,6 +2170,7 @@ NODE_CLASS_MAPPINGS = {
"WanVideoFlowEdit": WanVideoFlowEdit,
"WanVideoControlEmbeds": WanVideoControlEmbeds,
"WanVideoSLG": WanVideoSLG,
"WanVideoTinyVAELoader": WanVideoTinyVAELoader,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoSampler": "WanVideo Sampler",
@@ -2147,4 +2197,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoFlowEdit": "WanVideo FlowEdit",
"WanVideoControlEmbeds": "WanVideo Control Embeds",
"WanVideoSLG": "WanVideo SLG",
"WanVideoTinyVAELoader": "WanVideo Tiny VAE Loader",
}

1
taehv/__init__.py Normal file
View File

@@ -0,0 +1 @@
from .taehv import TAEHV

284
taehv/taehv.py Normal file
View File

@@ -0,0 +1,284 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f*stride,n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
- show_progress_bar: if True, enables tqdm progressbar display
Returns NTCHW tensor of output data.
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N*T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
# TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
# need to fix :(
out = []
# iterate over input timesteps and also iterate over blocks.
# because of the cursed TPool/TGrow blocks, this is not a nested loop,
# it's actually a ***graph traversal*** problem! so let's make a queue
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
# in addition to manually managing our queue, we also need to manually manage our progressbar.
# we'll update it for every source node that we consume.
progress_bar = tqdm(range(T), disable=not show_progress_bar)
# we'll also need a separate addressable memory per node as well
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
# new source node consumed
progress_bar.update(1)
if i == len(model):
# reached end of the graph, append result to output list
out.append(xt)
else:
# fetch the block to process
b = model[i]
if isinstance(b, MemBlock):
# mem blocks are simple since we're visiting the graph in causal order
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? doesn't help though
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_new, i+1))
elif isinstance(b, TPool):
# pool blocks are miserable
if mem[i] is None:
mem[i] = [] # pool memory is itself a queue of inputs to pool
mem[i].append(xt)
if len(mem[i]) > b.stride:
# pool mem is in invalid state, we should have pooled before this
raise ValueError("???")
elif len(mem[i]) < b.stride:
# pool mem is not yet full, go back to processing the work queue
pass
else:
# pool mem is ready, run the pool block
N, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(N*b.stride, C, H, W))
# reset the pool mem
mem[i] = []
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i+1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
# each tgrow has multiple successor nodes
for xt_next in reversed(xt.view(N, b.stride*C, H, W).chunk(b.stride, 1)):
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_next, i+1))
else:
# normal block with no funny business
xt = b(xt)
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i+1))
progress_bar.close()
x = torch.stack(out, 1)
return x
class TAEHV(nn.Module):
latent_channels = 16
image_channels = 3
def __init__(self, state_dict, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
"""Initialize pretrained TAEHV from the given checkpoint.
Arg:
checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
"""
super().__init__()
self.encoder = nn.Sequential(
conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
conv(64, TAEHV.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
)
if state_dict is not None:
self.load_state_dict(self.patch_tgrow_layers(state_dict))
self.dtype = torch.float16
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed.
Args:
sd: state dict to patch
"""
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if sd[key].shape[0] > new_sd[key].shape[0]:
# take the last-timestep output channels
sd[key] = sd[key][-new_sd[key].shape[0]:]
return sd
def encode_video(self, x, parallel=True, show_progress_bar=True):
"""Encode a sequence of frames.
Args:
x: input NTCHW RGB (C=3) tensor with values in [0, 1].
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW latent tensor with ~Gaussian values.
"""
return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
def decode_video(self, x, parallel=True, show_progress_bar=True):
"""Decode a sequence of frames.
Args:
x: input NTCHW latent (C=12) tensor with ~Gaussian values.
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW RGB tensor with ~[0, 1] values.
"""
x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
return x[:, self.frames_to_trim:]
def forward(self, x):
return self.c(x)
@torch.no_grad()
def main():
"""Run TAEHV roundtrip reconstruction on the given video paths."""
import sys
import cv2 # no highly esteemed deed is commemorated here
class VideoTensorReader:
def __init__(self, video_file_path):
self.cap = cv2.VideoCapture(video_file_path)
assert self.cap.isOpened(), f"Could not load {video_file_path}"
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
def __iter__(self):
return self
def __next__(self):
ret, frame = self.cap.read()
if not ret:
self.cap.release()
raise StopIteration # End of video or error
return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW
class VideoTensorWriter:
def __init__(self, video_file_path, width_height, fps=30):
self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
def write(self, frame_tensor):
assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR)) # RGB CHW -> BGR HWC
def __del__(self):
if hasattr(self, 'writer'): self.writer.release()
dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
dtype = torch.float16
print("Using device", dev, "and dtype", dtype)
taehv = TAEHV().to(dev, dtype)
for video_path in sys.argv[1:]:
print(f"Processing {video_path}...")
video_in = VideoTensorReader(video_path)
video = torch.stack(list(video_in), 0)[None]
vid_dev = video.to(dev, dtype).div_(255.0)
# convert to device tensor
if video.numel() < 100_000_000:
print(f" {video_path} seems small enough, will process all frames in parallel")
# convert to device tensor
vid_enc = taehv.encode_video(vid_dev)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc)
print(f" Decoded {video_path} -> {vid_dec.shape}")
else:
print(f" {video_path} seems large, will process each frame sequentially")
# convert to device tensor
vid_enc = taehv.encode_video(vid_dev, parallel=False)
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
vid_dec = taehv.decode_video(vid_enc, parallel=False)
print(f" Decoded {video_path} -> {vid_dec.shape}")
video_out_path = video_path + ".reconstructed_by_taehv.mp4"
video_out = VideoTensorWriter(video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
video_out.write(frame)
print(f" Saved to {video_out_path}")
if __name__ == "__main__":
main()