diff --git a/latent_preview.py b/latent_preview.py
new file mode 100644
index 0000000..e0b91ca
--- /dev/null
+++ b/latent_preview.py
@@ -0,0 +1,189 @@
+import torch
+from PIL import Image
+from comfy.cli_args import args, LatentPreviewMethod
+from comfy.taesd.taesd import TAESD
+import comfy.model_management
+import folder_paths
+import comfy.utils
+import logging
+import os
+
+from .taehv import TAEHV
+
+MAX_PREVIEW_RESOLUTION = args.preview_size
+
+def preview_to_image(latent_image):
+    print("latent_image shape: ", latent_image.shape)  # torch.Size([60, 104, 3])
+    latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1)  # change scale from -1..1 to 0..1
+                     .mul(0xFF)  # to 0..255
+                     )
+    if comfy.model_management.directml_enabled:
+        latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
+    latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))
+
+    return Image.fromarray(latents_ubyte.numpy())
+
+class LatentPreviewer:
+    def decode_latent_to_preview(self, x0):
+        pass
+
+    def decode_latent_to_preview_image(self, preview_format, x0):
+        preview_image = self.decode_latent_to_preview(x0)
+        return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)
+
+class TAESDPreviewerImpl(LatentPreviewer):
+    def __init__(self, taesd):
+        self.taesd = taesd
+
+    # def decode_latent_to_preview(self, x0):
+    #     #x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
+    #     print("x0 shape: ", x0.shape)  # torch.Size([5, 16, 60, 104])
+    #     x0 = x0.unsqueeze(0)
+    #     print("x0 shape: ", x0.shape)  # torch.Size([5, 16, 60, 104])
+    #     x_sample = self.taesd.decode_video(x0, parallel=False)[0].permute(0, 2, 3, 1)[0]
+    #     print("x_sample shape: ", x_sample.shape)
+    #     return preview_to_image(x_sample)
+
+
+class Latent2RGBPreviewer(LatentPreviewer):
+    def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
+        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
+        self.latent_rgb_factors_bias = None
+        if latent_rgb_factors_bias is not None:
+            self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
+
+    def decode_latent_to_preview(self, x0):
+        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
+        if self.latent_rgb_factors_bias is not None:
+            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
+
+        if x0.ndim == 5:
+            x0 = x0[0, :, 0]
+        else:
+            x0 = x0[0]
+
+        latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
+        # latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors
+
+        return preview_to_image(latent_image)
+
+
+def get_previewer(device, latent_format):
+    previewer = None
+    method = args.preview_method
+    if method != LatentPreviewMethod.NoPreviews:
+        # TODO previewer methods
+        taesd_decoder_path = None
+
+        if method == LatentPreviewMethod.Auto:
+            method = LatentPreviewMethod.Latent2RGB
+
+        if method == LatentPreviewMethod.TAESD:
+            taehv_path = os.path.join(folder_paths.models_dir, "vae_approx", "taew2_1.safetensors")
+            if not os.path.exists(taehv_path):
+                raise RuntimeError(f"Could not find {taehv_path}")
+            taew_sd = comfy.utils.load_torch_file(taehv_path)
+            taesd = TAEHV(taew_sd).to(device)
+            previewer = TAESDPreviewerImpl(taesd)
+            previewer = WrappedPreviewer(previewer)
+
+        if previewer is None:
+            if latent_format.latent_rgb_factors is not None:
+                previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
+    return previewer
+
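+# Builds the per-step sampling callback: it updates the ComfyUI progress bar and, when a
+# previewer is available, forwards a preview image of the current x0 prediction.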
+def prepare_callback(model, steps, x0_output_dict=None):
+    preview_format = "JPEG"
+    if preview_format not in ["JPEG", "PNG"]:
+        preview_format = "JPEG"
+
+    previewer = get_previewer(model.load_device, model.model.latent_format)
+    print("previewer: ", previewer)
+
+    pbar = comfy.utils.ProgressBar(steps)
+    def callback(step, x0, x, total_steps):
+        if x0_output_dict is not None:
+            x0_output_dict["x0"] = x0
+
+        preview_bytes = None
+        if previewer:
+            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
+        pbar.update_absolute(step + 1, total_steps, preview_bytes)
+    return callback
+
+# borrowed from VideoHelperSuite: https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/blob/main/videohelpersuite/latent_preview.py
+import server
+from threading import Thread
+import torch.nn.functional as F
+import io
+import time
+serv = server.PromptServer.instance
+
+class WrappedPreviewer(LatentPreviewer):
+    def __init__(self, previewer, rate=16):
+        self.first_preview = True
+        self.last_time = 0
+        self.c_index = 0
+        self.rate = rate
+        if hasattr(previewer, 'taesd'):
+            self.taesd = previewer.taesd
+        elif hasattr(previewer, 'latent_rgb_factors'):
+            self.latent_rgb_factors = previewer.latent_rgb_factors
+            self.latent_rgb_factors_bias = previewer.latent_rgb_factors_bias
+        else:
+            raise Exception('Unsupported preview type for VHS animated previews')
+
+    def decode_latent_to_preview_image(self, preview_format, x0):
+        if x0.ndim == 5:
+            # Keep batch major
+            x0 = x0.movedim(2, 1)
+            x0 = x0.reshape((-1,) + x0.shape[-3:])
+        num_images = x0.size(0)
+        new_time = time.time()
+        # number of frames to push on this call, paced by the target preview rate
+        num_previews = int((new_time - self.last_time) * self.rate)
+        self.last_time = self.last_time + num_previews / self.rate
+        if num_previews > num_images:
+            num_previews = num_images
+        elif num_previews <= 0:
+            return None
+        if self.first_preview:
+            self.first_preview = False
+            serv.send_sync('VHS_latentpreview', {'length': num_images, 'rate': self.rate})
+            self.last_time = new_time + 1 / self.rate
+        # cycle through the frames, wrapping around at num_images
+        if self.c_index + num_previews > num_images:
+            x0 = x0.roll(-self.c_index, 0)[:num_previews]
+        else:
+            x0 = x0[self.c_index:self.c_index + num_previews]
+        Thread(target=self.process_previews, args=(x0, self.c_index, num_images)).run()
+        self.c_index = (self.c_index + num_previews) % num_images
+        return None
+
+    def process_previews(self, image_tensor, ind, leng):
+        max_size = 256
+        image_tensor = self.decode_latent_to_preview(image_tensor)
+        if image_tensor.size(1) > max_size or image_tensor.size(2) > max_size:
+            image_tensor = image_tensor.movedim(-1, 0)
+            if image_tensor.size(2) < image_tensor.size(3):
+                height = (max_size * image_tensor.size(2)) // image_tensor.size(3)
+                image_tensor = F.interpolate(image_tensor, (height, max_size), mode='bilinear')
+            else:
+                width = (max_size * image_tensor.size(3)) // image_tensor.size(2)
+                image_tensor = F.interpolate(image_tensor, (max_size, width), mode='bilinear')
+            image_tensor = image_tensor.movedim(0, -1)
+        previews_ubyte = (image_tensor.clamp(0, 1)
+                          .mul(0xFF)  # to 0..255
+                          ).to(device="cpu", dtype=torch.uint8)
+        for preview in previews_ubyte:
+            i = Image.fromarray(preview.numpy())
+            message = io.BytesIO()
+            message.write((1).to_bytes(length=4, byteorder='big') * 2)
+            message.write(ind.to_bytes(length=4, byteorder='big'))
+            i.save(message, format="JPEG", quality=95, compress_level=1)
+            # NOTE: send_sync already uses call_soon_threadsafe
+            serv.send_sync(server.BinaryEventTypes.PREVIEW_IMAGE,
+                           message.getvalue(), serv.client_id)
+            ind = (ind + 1) % leng
+
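+    # x0 here is a TCHW batch of latent frames handed over by process_previews; decode it
+    # with the tiny video VAE and return THWC frames in roughly [0, 1].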
+    def decode_latent_to_preview(self, x0):
+        x0 = x0.unsqueeze(0)
+        x_sample = self.taesd.decode_video(x0, parallel=False)[0].permute(0, 2, 3, 1)
+        return x_sample
diff --git a/nodes.py b/nodes.py
index a14ad7b..8edafe7 100644
--- a/nodes.py
+++ b/nodes.py
@@ -15,6 +15,7 @@ from .wanvideo.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 
 from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight, set_num_frames
+from .taehv import TAEHV
 
 from accelerate import init_empty_weights
 from accelerate.utils import set_module_tensor_to_device
@@ -678,6 +679,42 @@ class WanVideoVAELoader:
 
         return (vae,)
 
 
+class WanVideoTinyVAELoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model_name": (folder_paths.get_filename_list("vae_approx"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae_approx'"}),
+            },
+            "optional": {
+                "precision": (["fp16", "fp32", "bf16"],
+                    {"default": "fp16"}
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("WANVAE",)
+    RETURN_NAMES = ("vae", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "WanVideoWrapper"
+    DESCRIPTION = "Loads a TAEHV tiny VAE model (e.g. taew2_1.safetensors) from 'ComfyUI/models/vae_approx'"
+
+    def loadmodel(self, model_name, precision):
+        from .taehv import TAEHV
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        model_path = folder_paths.get_full_path("vae_approx", model_name)
+        vae_sd = load_torch_file(model_path, safe_load=True)
+
+        vae = TAEHV(vae_sd)
+
+        vae.to(device=offload_device, dtype=dtype)
+
+        return (vae,)
+
 class WanVideoTorchCompileSettings:
@@ -1403,7 +1440,7 @@ class WanVideoSampler:
 
         pbar = ProgressBar(steps)
 
-        from latent_preview import prepare_callback
+        from .latent_preview import prepare_callback
         callback = prepare_callback(patcher, steps)
 
         #blockswap init
@@ -1778,9 +1815,13 @@ class WanVideoSampler:
                     to_decode = self.previous_noise_pred_context[:,-1,:, :].unsqueeze(1).unsqueeze(0).to(context_vae.dtype)
                     #to_decode = to_decode.permute(0, 1, 3, 2)
                     #print("to_decode.shape", to_decode.shape)
-                    image = context_vae.decode(to_decode, device=device, tiled=False)[0]
+                    if isinstance(context_vae, TAEHV):
+                        image = context_vae.decode_video(to_decode.permute(0, 2, 1, 3, 4), parallel=False)[0].permute(1, 0, 2, 3)
+                        image = context_vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False).permute(0, 2, 1, 3, 4)
+                    else:
+                        image = context_vae.decode(to_decode, device=device, tiled=False)[0]
+                        image = context_vae.encode(image.unsqueeze(0).to(context_vae.dtype), device=device, tiled=False)
                     #print("decoded image.shape", image.shape) #torch.Size([3, 37, 832, 480])
-                    image = context_vae.encode(image.unsqueeze(0).to(context_vae.dtype), device=device, tiled=False)
                     #print("encoded image.shape", image.shape)
                     #partial_img_emb[:, 0, :, :] = image[0][:,0,:,:]
                     #print("partial_img_emb.shape", partial_img_emb.shape)
@@ -1934,14 +1975,17 @@ class WanVideoDecode:
 
         mm.soft_empty_cache()
 
-        image = vae.decode(latents, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))[0]
-        print(image.shape)
-        print(image.min(), image.max())
+        if isinstance(vae, TAEHV):
+            image = vae.decode_video(latents.permute(0, 2, 1, 3, 4))[0].permute(1, 0, 2, 3)
+        else:
+            image = vae.decode(latents, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))[0]
+            vae.model.clear_cache()
+        image = (image - image.min()) / (image.max() - image.min())
 
         vae.to(offload_device)
-        vae.model.clear_cache()
+
         mm.soft_empty_cache()
 
-        image = (image - image.min()) / (image.max() - image.min())
+        image = torch.clamp(image, 0.0, 1.0)
         image = image.permute(1, 2, 3, 0).cpu().float()
 
@@ -1978,16 +2022,21 @@ class WanVideoEncode:
 
         vae.to(device)
 
-        image = (image.clone() * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
+        image = (image.clone()).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
         if noise_aug_strength > 0.0:
            image = add_noise_to_reference_video(image, ratio=noise_aug_strength)
-
-        latents = vae.encode(image, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))
+
+        if isinstance(vae, TAEHV):
+            latents = vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False) # B, T, C, H, W
+            latents = latents.permute(0, 2, 1, 3, 4)
+        else:
+            latents = vae.encode(image * 2.0 - 1.0, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))
+            vae.model.clear_cache()
         if latent_strength != 1.0:
             latents *= latent_strength
 
         vae.to(offload_device)
-        vae.model.clear_cache()
+
         mm.soft_empty_cache()
 
         print("encoded latents shape",latents.shape)
@@ -2121,6 +2170,7 @@ NODE_CLASS_MAPPINGS = {
     "WanVideoFlowEdit": WanVideoFlowEdit,
     "WanVideoControlEmbeds": WanVideoControlEmbeds,
     "WanVideoSLG": WanVideoSLG,
+    "WanVideoTinyVAELoader": WanVideoTinyVAELoader,
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "WanVideoSampler": "WanVideo Sampler",
@@ -2147,4 +2197,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "WanVideoFlowEdit": "WanVideo FlowEdit",
     "WanVideoControlEmbeds": "WanVideo Control Embeds",
     "WanVideoSLG": "WanVideo SLG",
+    "WanVideoTinyVAELoader": "WanVideo Tiny VAE Loader",
     }
diff --git a/taehv/__init__.py b/taehv/__init__.py
new file mode 100644
index 0000000..c6e076b
--- /dev/null
+++ b/taehv/__init__.py
@@ -0,0 +1 @@
+from .taehv import TAEHV
\ No newline at end of file
diff --git a/taehv/taehv.py b/taehv/taehv.py
new file mode 100644
index 0000000..d6adf23
--- /dev/null
+++ b/taehv/taehv.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+"""
+Tiny AutoEncoder for Hunyuan Video
+(DNN for encoding / decoding videos to Hunyuan Video's latent space)
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm.auto import tqdm
+from collections import namedtuple
+
+DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
+TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
+
+def conv(n_in, n_out, **kwargs):
+    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
+
+class Clamp(nn.Module):
+    def forward(self, x):
+        return torch.tanh(x / 3) * 3
+
+class MemBlock(nn.Module):
+    def __init__(self, n_in, n_out):
+        super().__init__()
+        self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
+        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
+        self.act = nn.ReLU(inplace=True)
+    def forward(self, x, past):
+        return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
+
+class TPool(nn.Module):
+    def __init__(self, n_f, stride):
+        super().__init__()
+        self.stride = stride
+        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
+    def forward(self, x):
+        _NT, C, H, W = x.shape
+        return self.conv(x.reshape(-1, self.stride * C, H, W))
+
+class TGrow(nn.Module):
+    def __init__(self, n_f, stride):
+        super().__init__()
+        self.stride = stride
+        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
+    def forward(self, x):
+        _NT, C, H, W = x.shape
+        x = self.conv(x)
+        return x.reshape(-1, C, H, W)
+
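+# The three blocks above carry the temporal state: MemBlock conditions each frame on the
+# previous frame's features via channel concatenation, TPool fuses `stride` consecutive
+# frames into one with a 1x1 conv, and TGrow expands one frame into `stride` frames.
+# apply_model_with_memblocks runs a Sequential of such blocks either fully in parallel
+# or frame by frame with O(1) extra memory.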
+def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
+    """
+    Apply a sequential model with memblocks to the given input.
+    Args:
+    - model: nn.Sequential of blocks to apply
+    - x: input data, of dimensions NTCHW
+    - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
+        if False, each timestep will be processed sequentially (slow but uses O(1) memory)
+    - show_progress_bar: if True, enables tqdm progressbar display
+
+    Returns NTCHW tensor of output data.
+    """
+    assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
+    N, T, C, H, W = x.shape
+    if parallel:
+        x = x.reshape(N * T, C, H, W)
+        # parallel over input timesteps, iterate over blocks
+        for b in tqdm(model, disable=not show_progress_bar):
+            if isinstance(b, MemBlock):
+                NT, C, H, W = x.shape
+                T = NT // N
+                _x = x.reshape(N, T, C, H, W)
+                mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
+                x = b(x, mem)
+            else:
+                x = b(x)
+        NT, C, H, W = x.shape
+        T = NT // N
+        x = x.view(N, T, C, H, W)
+    else:
+        # TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
+        # need to fix :(
+        out = []
+        # iterate over input timesteps and also iterate over blocks.
+        # because of the cursed TPool/TGrow blocks, this is not a nested loop,
+        # it's actually a ***graph traversal*** problem! so let's make a queue
+        work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
+        # in addition to manually managing our queue, we also need to manually manage our progressbar.
+        # we'll update it for every source node that we consume.
+        progress_bar = tqdm(range(T), disable=not show_progress_bar)
+        # we'll also need a separate addressable memory per node as well
+        mem = [None] * len(model)
+        while work_queue:
+            xt, i = work_queue.pop(0)
+            if i == 0:
+                # new source node consumed
+                progress_bar.update(1)
+            if i == len(model):
+                # reached end of the graph, append result to output list
+                out.append(xt)
+            else:
+                # fetch the block to process
+                b = model[i]
+                if isinstance(b, MemBlock):
+                    # mem blocks are simple since we're visiting the graph in causal order
+                    if mem[i] is None:
+                        xt_new = b(xt, xt * 0)
+                        mem[i] = xt
+                    else:
+                        xt_new = b(xt, mem[i])
+                        mem[i].copy_(xt)  # inplace might reduce mysterious pytorch memory allocations? doesn't help though
+                    # add successor to work queue
+                    work_queue.insert(0, TWorkItem(xt_new, i + 1))
+                elif isinstance(b, TPool):
+                    # pool blocks are miserable
+                    if mem[i] is None:
+                        mem[i] = []  # pool memory is itself a queue of inputs to pool
+                    mem[i].append(xt)
+                    if len(mem[i]) > b.stride:
+                        # pool mem is in invalid state, we should have pooled before this
+                        raise ValueError("???")
+                    elif len(mem[i]) < b.stride:
+                        # pool mem is not yet full, go back to processing the work queue
+                        pass
+                    else:
+                        # pool mem is ready, run the pool block
+                        N, C, H, W = xt.shape
+                        xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
+                        # reset the pool mem
+                        mem[i] = []
+                        # add successor to work queue
+                        work_queue.insert(0, TWorkItem(xt, i + 1))
+                elif isinstance(b, TGrow):
+                    xt = b(xt)
+                    NT, C, H, W = xt.shape
+                    # each tgrow has multiple successor nodes
+                    for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
+                        # add successor to work queue
+                        work_queue.insert(0, TWorkItem(xt_next, i + 1))
+                else:
+                    # normal block with no funny business
+                    xt = b(xt)
+                    # add successor to work queue
+                    work_queue.insert(0, TWorkItem(xt, i + 1))
+        progress_bar.close()
+        x = torch.stack(out, 1)
+    return x
+
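+# Shape summary (default upscale settings): the encoder maps NTCHW RGB in [0, 1] to a
+# 16-channel latent at 1/8 spatial and 1/4 temporal resolution; the decoder inverts this
+# and decode_video trims the first 2**2 - 1 = 3 warm-up frames. For example, a
+# (1, 16, 3, 256, 256) clip encodes to (1, 4, 16, 32, 32) and decodes back to
+# (1, 13, 3, 256, 256).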
+class TAEHV(nn.Module):
+    latent_channels = 16
+    image_channels = 3
+    def __init__(self, state_dict=None, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
+        """Initialize pretrained TAEHV from the given state dict.
+
+        Args:
+            state_dict: weights to load (taehv.pth for Hunyuan Video, taew2_1.pth for Wan 2.1); None leaves the model randomly initialized.
+            decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
+            decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
+        """
+        super().__init__()
+        self.encoder = nn.Sequential(
+            conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
+            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
+            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
+            TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
+            conv(64, TAEHV.latent_channels),
+        )
+        n_f = [256, 128, 64, 64]
+        self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
+        self.decoder = nn.Sequential(
+            Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
+            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
+            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
+            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
+            nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
+        )
+        if state_dict is not None:
+            self.load_state_dict(self.patch_tgrow_layers(state_dict))
+        self.dtype = torch.float16
+
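+    # When temporal upsampling is disabled, the TGrow convs above have fewer output
+    # channels than the full checkpoint, so the stored weights are sliced down before
+    # loading (see patch_tgrow_layers below).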
+    def patch_tgrow_layers(self, sd):
+        """Patch TGrow layers to use a smaller kernel if needed.
+
+        Args:
+            sd: state dict to patch
+        """
+        new_sd = self.state_dict()
+        for i, layer in enumerate(self.decoder):
+            if isinstance(layer, TGrow):
+                key = f"decoder.{i}.conv.weight"
+                if sd[key].shape[0] > new_sd[key].shape[0]:
+                    # take the last-timestep output channels
+                    sd[key] = sd[key][-new_sd[key].shape[0]:]
+        return sd
+
+    def encode_video(self, x, parallel=True, show_progress_bar=True):
+        """Encode a sequence of frames.
+
+        Args:
+            x: input NTCHW RGB (C=3) tensor with values in [0, 1].
+            parallel: if True, all frames will be processed at once.
+                (this is faster but may require more memory).
+                if False, frames will be processed sequentially.
+        Returns NTCHW latent tensor with ~Gaussian values.
+        """
+        return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
+
+    def decode_video(self, x, parallel=True, show_progress_bar=True):
+        """Decode a sequence of frames.
+
+        Args:
+            x: input NTCHW latent (C=16) tensor with ~Gaussian values.
+            parallel: if True, all frames will be processed at once.
+                (this is faster but may require more memory).
+                if False, frames will be processed sequentially.
+        Returns NTCHW RGB tensor with ~[0, 1] values.
+        """
+        x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
+        return x[:, self.frames_to_trim:]
+
+    def forward(self, x):
+        # encode then decode (round-trip reconstruction)
+        return self.decode_video(self.encode_video(x))
+
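+# CLI smoke test: `python taehv.py some_video.mp4 ...` round-trips each input through
+# encode_video/decode_video and writes <input>.reconstructed_by_taehv.mp4 next to it.
+# Note that TAEHV() below is constructed without a state dict, i.e. with random weights,
+# unless a checkpoint is wired in.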
Decoding...") + vid_dec = taehv.decode_video(vid_enc) + print(f" Decoded {video_path} -> {vid_dec.shape}") + else: + print(f" {video_path} seems large, will process each frame sequentially") + # convert to device tensor + vid_enc = taehv.encode_video(vid_dev, parallel=False) + print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...") + vid_dec = taehv.decode_video(vid_enc, parallel=False) + print(f" Decoded {video_path} -> {vid_dec.shape}") + video_out_path = video_path + ".reconstructed_by_taehv.mp4" + video_out = VideoTensorWriter(video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps))) + for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]: + video_out.write(frame) + print(f" Saved to {video_out_path}") + +if __name__ == "__main__": + main()