ComfyUI-WanVideoWrapper (mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git)

Commit: support tiny vae
latent_preview.py (new file, 189 lines)
@@ -0,0 +1,189 @@
import torch
from PIL import Image
from comfy.cli_args import args, LatentPreviewMethod
from comfy.taesd.taesd import TAESD
import comfy.model_management
import folder_paths
import comfy.utils
import logging
import os

from .taehv import TAEHV

MAX_PREVIEW_RESOLUTION = args.preview_size

def preview_to_image(latent_image):
    print("latent_image shape: ", latent_image.shape)#torch.Size([60, 104, 3])
    latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1)  # change scale from -1..1 to 0..1
                     .mul(0xFF)  # to 0..255
                     )
    if comfy.model_management.directml_enabled:
        latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
    latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=comfy.model_management.device_supports_non_blocking(latent_image.device))

    return Image.fromarray(latents_ubyte.numpy())

class LatentPreviewer:
    def decode_latent_to_preview(self, x0):
        pass

    def decode_latent_to_preview_image(self, preview_format, x0):
        preview_image = self.decode_latent_to_preview(x0)
        return ("JPEG", preview_image, MAX_PREVIEW_RESOLUTION)

class TAESDPreviewerImpl(LatentPreviewer):
    def __init__(self, taesd):
        self.taesd = taesd

    # def decode_latent_to_preview(self, x0):
    #     #x_sample = self.taesd.decode(x0[:1])[0].movedim(0, 2)
    #     print("x0 shape: ", x0.shape) #torch.Size([5, 16, 60, 104])
    #     x0 = x0.unsqueeze(0)
    #     print("x0 shape: ", x0.shape) #torch.Size([5, 16, 60, 104])
    #     x_sample = self.taesd.decode_video(x0, parallel=False)[0].permute(0, 2, 3, 1)[0]
    #     print("x_sample shape: ", x_sample.shape)
    #     return preview_to_image(x_sample)


class Latent2RGBPreviewer(LatentPreviewer):
    def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
        self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
        self.latent_rgb_factors_bias = None
        if latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")

    def decode_latent_to_preview(self, x0):
        self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
        if self.latent_rgb_factors_bias is not None:
            self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)

        if x0.ndim == 5:
            x0 = x0[0, :, 0]
        else:
            x0 = x0[0]

        latent_image = torch.nn.functional.linear(x0.movedim(0, -1), self.latent_rgb_factors, bias=self.latent_rgb_factors_bias)
        # latent_image = x0[0].permute(1, 2, 0) @ self.latent_rgb_factors

        return preview_to_image(latent_image)


def get_previewer(device, latent_format):
    previewer = None
    method = args.preview_method
    if method != LatentPreviewMethod.NoPreviews:
        # TODO previewer methods
        taesd_decoder_path = None

        if method == LatentPreviewMethod.Auto:
            method = LatentPreviewMethod.Latent2RGB

        if method == LatentPreviewMethod.TAESD:
            taehv_path = os.path.join(folder_paths.models_dir, "vae_approx", "taew2_1.safetensors")
            if not os.path.exists(taehv_path):
                raise RuntimeError(f"Could not find {taehv_path}")
            taew_sd = comfy.utils.load_torch_file(taehv_path)
            taesd = TAEHV(taew_sd).to(device)
            previewer = TAESDPreviewerImpl(taesd)
            previewer = WrappedPreviewer(previewer)

        if previewer is None:
            if latent_format.latent_rgb_factors is not None:
                previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
    return previewer

def prepare_callback(model, steps, x0_output_dict=None):
    preview_format = "JPEG"
    if preview_format not in ["JPEG", "PNG"]:
        preview_format = "JPEG"

    previewer = get_previewer(model.load_device, model.model.latent_format)
    print("previewer: ", previewer)

    pbar = comfy.utils.ProgressBar(steps)
    def callback(step, x0, x, total_steps):
        if x0_output_dict is not None:
            x0_output_dict["x0"] = x0

        preview_bytes = None
        if previewer:
            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
        pbar.update_absolute(step + 1, total_steps, preview_bytes)
    return callback

#borrowed VideoHelperSuite https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite/blob/main/videohelpersuite/latent_preview.py
import server
from threading import Thread
import torch.nn.functional as F
import io
import time
serv = server.PromptServer.instance

class WrappedPreviewer(LatentPreviewer):
    def __init__(self, previewer, rate=16):
        self.first_preview = True
        self.last_time = 0
        self.c_index = 0
        self.rate = rate
        if hasattr(previewer, 'taesd'):
            self.taesd = previewer.taesd
        elif hasattr(previewer, 'latent_rgb_factors'):
            self.latent_rgb_factors = previewer.latent_rgb_factors
            self.latent_rgb_factors_bias = previewer.latent_rgb_factors_bias
        else:
            raise Exception('Unsupported preview type for VHS animated previews')

    def decode_latent_to_preview_image(self, preview_format, x0):
        if x0.ndim == 5:
            #Keep batch major
            x0 = x0.movedim(2,1)
            x0 = x0.reshape((-1,)+x0.shape[-3:])
        num_images = x0.size(0)
        new_time = time.time()
        num_previews = int((new_time - self.last_time) * self.rate)
        self.last_time = self.last_time + num_previews/self.rate
        if num_previews > num_images:
            num_previews = num_images
        elif num_previews <= 0:
            return None
        if self.first_preview:
            self.first_preview = False
            serv.send_sync('VHS_latentpreview', {'length':num_images, 'rate': self.rate})
            self.last_time = new_time + 1/self.rate
        if self.c_index + num_previews > num_images:
            x0 = x0.roll(-self.c_index, 0)[:num_previews]
        else:
            x0 = x0[self.c_index:self.c_index + num_previews]
        Thread(target=self.process_previews, args=(x0, self.c_index,
                                                   num_images)).run()
        self.c_index = (self.c_index + num_previews) % num_images
        return None
    def process_previews(self, image_tensor, ind, leng):
        max_size = 256
        image_tensor = self.decode_latent_to_preview(image_tensor)
        if image_tensor.size(1) > max_size or image_tensor.size(2) > max_size:
            image_tensor = image_tensor.movedim(-1,0)
            if image_tensor.size(2) < image_tensor.size(3):
                height = (max_size * image_tensor.size(2)) // image_tensor.size(3)
                image_tensor = F.interpolate(image_tensor, (height,max_size), mode='bilinear')
            else:
                width = (max_size * image_tensor.size(3)) // image_tensor.size(2)
                image_tensor = F.interpolate(image_tensor, (max_size, width), mode='bilinear')
            image_tensor = image_tensor.movedim(0,-1)
        previews_ubyte = (image_tensor.clamp(0, 1)
                          .mul(0xFF)  # to 0..255
                          ).to(device="cpu", dtype=torch.uint8)
        for preview in previews_ubyte:
            i = Image.fromarray(preview.numpy())
            message = io.BytesIO()
            message.write((1).to_bytes(length=4, byteorder='big')*2)
            message.write(ind.to_bytes(length=4, byteorder='big'))
            i.save(message, format="JPEG", quality=95, compress_level=1)
            #NOTE: send sync already uses call_soon_threadsafe
            serv.send_sync(server.BinaryEventTypes.PREVIEW_IMAGE,
                           message.getvalue(), serv.client_id)
            ind = (ind + 1) % leng
    def decode_latent_to_preview(self, x0):
        x0 = x0.unsqueeze(0)
        x_sample = self.taesd.decode_video(x0, parallel=False)[0].permute(0, 2, 3, 1)
        return x_sample
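
For orientation, a minimal sketch of how the callback returned by prepare_callback above is meant to be driven (patcher, steps and run_one_denoising_step are placeholders; the actual wiring is the WanVideoSampler change in nodes.py below):

callback = prepare_callback(patcher, steps)   # patcher: a loaded WanVideo model patcher (placeholder)
for step in range(steps):
    x0_pred = run_one_denoising_step()        # placeholder for the sampler's current x0 prediction (latent tensor)
    callback(step, x0_pred, None, steps)      # decodes a preview if a previewer is active and advances the progress bar
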
nodes.py (75 changed lines)
@@ -15,6 +15,7 @@ from .wanvideo.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

 from .enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight, set_num_frames
+from .taehv import TAEHV

 from accelerate import init_empty_weights
 from accelerate.utils import set_module_tensor_to_device
@@ -678,6 +679,42 @@ class WanVideoVAELoader:

         return (vae,)

+class WanVideoTinyVAELoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model_name": (folder_paths.get_filename_list("vae_approx"), {"tooltip": "These models are loaded from 'ComfyUI/models/vae_approx'"}),
+            },
+            "optional": {
+                "precision": (["fp16", "fp32", "bf16"],
+                    {"default": "fp16"}
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("WANVAE",)
+    RETURN_NAMES = ("vae", )
+    FUNCTION = "loadmodel"
+    CATEGORY = "WanVideoWrapper"
+    DESCRIPTION = "Loads a tiny TAEHV VAE model (e.g. taew2_1) from 'ComfyUI/models/vae_approx'"
+
+    def loadmodel(self, model_name, precision):
+        from .taehv import TAEHV
+
+        device = mm.get_torch_device()
+        offload_device = mm.unet_offload_device()
+
+        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
+        model_path = folder_paths.get_full_path("vae_approx", model_name)
+        vae_sd = load_torch_file(model_path, safe_load=True)
+
+        vae = TAEHV(vae_sd)
+
+        vae.to(device=offload_device, dtype=dtype)
+
+        return (vae,)
+


 class WanVideoTorchCompileSettings:
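
As a rough sketch of what the new loader amounts to outside the node graph (the file name and the latents tensor are illustrative; inside ComfyUI the node's WANVAE output is simply wired into the decode/encode nodes, which branch on isinstance(vae, TAEHV) in the hunks below):

# assumes taew2_1.safetensors has been placed in ComfyUI/models/vae_approx
tiny_vae, = WanVideoTinyVAELoader().loadmodel("taew2_1.safetensors", "fp16")
# wrapper latents are laid out B, C, T, H, W; TAEHV expects N, T, C, H, W, hence the permute
frames = tiny_vae.decode_video(latents.permute(0, 2, 1, 3, 4))[0]   # -> T, C, H, W frames, roughly in [0, 1]
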
@@ -1403,7 +1440,7 @@ class WanVideoSampler:

         pbar = ProgressBar(steps)

-        from latent_preview import prepare_callback
+        from .latent_preview import prepare_callback
         callback = prepare_callback(patcher, steps)

         #blockswap init
@@ -1778,9 +1815,13 @@ class WanVideoSampler:
                 to_decode = self.previous_noise_pred_context[:,-1,:, :].unsqueeze(1).unsqueeze(0).to(context_vae.dtype)
                 #to_decode = to_decode.permute(0, 1, 3, 2)
                 #print("to_decode.shape", to_decode.shape)
-                image = context_vae.decode(to_decode, device=device, tiled=False)[0]
+                if isinstance(context_vae, TAEHV):
+                    image = context_vae.decode_video(to_decode.permute(0, 2, 1, 3, 4), parallel=False)[0].permute(1, 0, 2, 3)
+                    image = context_vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False).permute(0, 2, 1, 3, 4)
+                else:
+                    image = context_vae.decode(to_decode, device=device, tiled=False)[0]
+                    image = context_vae.encode(image.unsqueeze(0).to(context_vae.dtype), device=device, tiled=False)
                 #print("decoded image.shape", image.shape) #torch.Size([3, 37, 832, 480])
-                image = context_vae.encode(image.unsqueeze(0).to(context_vae.dtype), device=device, tiled=False)
                 #print("encoded image.shape", image.shape)
                 #partial_img_emb[:, 0, :, :] = image[0][:,0,:,:]
                 #print("partial_img_emb.shape", partial_img_emb.shape)
@@ -1934,14 +1975,17 @@

         mm.soft_empty_cache()

-        image = vae.decode(latents, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))[0]
-        print(image.shape)
-        print(image.min(), image.max())
+        if isinstance(vae, TAEHV):
+            image = vae.decode_video(latents.permute(0, 2, 1, 3, 4))[0].permute(1, 0, 2, 3)
+        else:
+            image = vae.decode(latents, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))[0]
+            vae.model.clear_cache()
+        image = (image - image.min()) / (image.max() - image.min())
         vae.to(offload_device)
-        vae.model.clear_cache()

         mm.soft_empty_cache()

-        image = (image - image.min()) / (image.max() - image.min())

         image = torch.clamp(image, 0.0, 1.0)
         image = image.permute(1, 2, 3, 0).cpu().float()

@@ -1978,16 +2022,21 @@ class WanVideoEncode:

         vae.to(device)

-        image = (image.clone() * 2.0 - 1.0).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
+        image = (image.clone()).to(vae.dtype).to(device).unsqueeze(0).permute(0, 4, 1, 2, 3) # B, C, T, H, W
         if noise_aug_strength > 0.0:
             image = add_noise_to_reference_video(image, ratio=noise_aug_strength)

-        latents = vae.encode(image, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))

+        if isinstance(vae, TAEHV):
+            latents = vae.encode_video(image.permute(0, 2, 1, 3, 4), parallel=False)# B, T, C, H, W
+            latents = latents.permute(0, 2, 1, 3, 4)
+        else:
+            latents = vae.encode(image * 2.0 - 1.0, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))
+            vae.model.clear_cache()
         if latent_strength != 1.0:
             latents *= latent_strength

         vae.to(offload_device)
-        vae.model.clear_cache()

         mm.soft_empty_cache()
         print("encoded latents shape",latents.shape)

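
Two conventions in the branches above are easy to miss; a small illustrative sketch (tensor values are made up):

import torch
latents_bcthw = torch.randn(1, 16, 5, 60, 104)         # wrapper layout: B, C, T, H, W
latents_ntchw = latents_bcthw.permute(0, 2, 1, 3, 4)   # TAEHV layout:  N, T, C, H, W
# Value range: TAEHV's encode_video expects RGB in [0, 1], so the TAEHV branch feeds the image in
# unscaled, while the regular Wan VAE branch keeps the [-1, 1] scaling (image * 2.0 - 1.0).
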
@@ -2121,6 +2170,7 @@ NODE_CLASS_MAPPINGS = {
     "WanVideoFlowEdit": WanVideoFlowEdit,
     "WanVideoControlEmbeds": WanVideoControlEmbeds,
     "WanVideoSLG": WanVideoSLG,
+    "WanVideoTinyVAELoader": WanVideoTinyVAELoader,
     }
 NODE_DISPLAY_NAME_MAPPINGS = {
     "WanVideoSampler": "WanVideo Sampler",
@@ -2147,4 +2197,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "WanVideoFlowEdit": "WanVideo FlowEdit",
     "WanVideoControlEmbeds": "WanVideo Control Embeds",
     "WanVideoSLG": "WanVideo SLG",
+    "WanVideoTinyVAELoader": "WanVideo Tiny VAE Loader",
     }

taehv/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .taehv import TAEHV
taehv/taehv.py (new file, 284 lines)
@@ -0,0 +1,284 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple

DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))

def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)

class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3

class MemBlock(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.act = nn.ReLU(inplace=True)
    def forward(self, x, past):
        return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))

class TPool(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f*stride, n_f, 1, bias=False)
    def forward(self, x):
        _NT, C, H, W = x.shape
        return self.conv(x.reshape(-1, self.stride * C, H, W))

class TGrow(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False)
    def forward(self, x):
        _NT, C, H, W = x.shape
        x = self.conv(x)
        return x.reshape(-1, C, H, W)

def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
    """
    Apply a sequential model with memblocks to the given input.
    Args:
    - model: nn.Sequential of blocks to apply
    - x: input data, of dimensions NTCHW
    - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
        if False, each timestep will be processed sequentially (slow but uses O(1) memory)
    - show_progress_bar: if True, enables tqdm progressbar display

    Returns NTCHW tensor of output data.
    """
    assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
    N, T, C, H, W = x.shape
    if parallel:
        x = x.reshape(N*T, C, H, W)
        # parallel over input timesteps, iterate over blocks
        for b in tqdm(model, disable=not show_progress_bar):
            if isinstance(b, MemBlock):
                NT, C, H, W = x.shape
                T = NT // N
                _x = x.reshape(N, T, C, H, W)
                mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
                x = b(x, mem)
            else:
                x = b(x)
        NT, C, H, W = x.shape
        T = NT // N
        x = x.view(N, T, C, H, W)
    else:
        # TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
        # need to fix :(
        out = []
        # iterate over input timesteps and also iterate over blocks.
        # because of the cursed TPool/TGrow blocks, this is not a nested loop,
        # it's actually a ***graph traversal*** problem! so let's make a queue
        work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
        # in addition to manually managing our queue, we also need to manually manage our progressbar.
        # we'll update it for every source node that we consume.
        progress_bar = tqdm(range(T), disable=not show_progress_bar)
        # we'll also need a separate addressable memory per node as well
        mem = [None] * len(model)
        while work_queue:
            xt, i = work_queue.pop(0)
            if i == 0:
                # new source node consumed
                progress_bar.update(1)
            if i == len(model):
                # reached end of the graph, append result to output list
                out.append(xt)
            else:
                # fetch the block to process
                b = model[i]
                if isinstance(b, MemBlock):
                    # mem blocks are simple since we're visiting the graph in causal order
                    if mem[i] is None:
                        xt_new = b(xt, xt * 0)
                        mem[i] = xt
                    else:
                        xt_new = b(xt, mem[i])
                        mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? doesn't help though
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt_new, i+1))
                elif isinstance(b, TPool):
                    # pool blocks are miserable
                    if mem[i] is None:
                        mem[i] = [] # pool memory is itself a queue of inputs to pool
                    mem[i].append(xt)
                    if len(mem[i]) > b.stride:
                        # pool mem is in invalid state, we should have pooled before this
                        raise ValueError("???")
                    elif len(mem[i]) < b.stride:
                        # pool mem is not yet full, go back to processing the work queue
                        pass
                    else:
                        # pool mem is ready, run the pool block
                        N, C, H, W = xt.shape
                        xt = b(torch.cat(mem[i], 1).view(N*b.stride, C, H, W))
                        # reset the pool mem
                        mem[i] = []
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt, i+1))
                elif isinstance(b, TGrow):
                    xt = b(xt)
                    NT, C, H, W = xt.shape
                    # each tgrow has multiple successor nodes
                    for xt_next in reversed(xt.view(N, b.stride*C, H, W).chunk(b.stride, 1)):
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt_next, i+1))
                else:
                    # normal block with no funny business
                    xt = b(xt)
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt, i+1))
        progress_bar.close()
        x = torch.stack(out, 1)
    return x

class TAEHV(nn.Module):
    latent_channels = 16
    image_channels = 3
    def __init__(self, state_dict, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
        """Initialize pretrained TAEHV from the given state dict.

        Args:
            state_dict: state dict to load (e.g. from taehv.pth for Hunyuan or taew2_1.pth / taew2_1.safetensors for Wan 2.1); pass None to skip loading weights.
            decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
            decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
        """
        super().__init__()
        self.encoder = nn.Sequential(
            conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            conv(64, TAEHV.latent_channels),
        )
        n_f = [256, 128, 64, 64]
        self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
        self.decoder = nn.Sequential(
            Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
        )
        if state_dict is not None:
            self.load_state_dict(self.patch_tgrow_layers(state_dict))
        self.dtype = torch.float16

    def patch_tgrow_layers(self, sd):
        """Patch TGrow layers to use a smaller kernel if needed.

        Args:
            sd: state dict to patch
        """
        new_sd = self.state_dict()
        for i, layer in enumerate(self.decoder):
            if isinstance(layer, TGrow):
                key = f"decoder.{i}.conv.weight"
                if sd[key].shape[0] > new_sd[key].shape[0]:
                    # take the last-timestep output channels
                    sd[key] = sd[key][-new_sd[key].shape[0]:]
        return sd

    def encode_video(self, x, parallel=True, show_progress_bar=True):
        """Encode a sequence of frames.

        Args:
            x: input NTCHW RGB (C=3) tensor with values in [0, 1].
            parallel: if True, all frames will be processed at once.
                (this is faster but may require more memory).
                if False, frames will be processed sequentially.
        Returns NTCHW latent tensor with ~Gaussian values.
        """
        return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)

    def decode_video(self, x, parallel=True, show_progress_bar=True):
        """Decode a sequence of frames.

        Args:
            x: input NTCHW latent (C=16) tensor with ~Gaussian values.
            parallel: if True, all frames will be processed at once.
                (this is faster but may require more memory).
                if False, frames will be processed sequentially.
        Returns NTCHW RGB tensor with ~[0, 1] values.
        """
        x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
        return x[:, self.frames_to_trim:]

    def forward(self, x):
        # NOTE: self.c is never defined on this module, so calling forward() would fail;
        # use encode_video() / decode_video() instead.
        return self.c(x)

@torch.no_grad()
def main():
    """Run TAEHV roundtrip reconstruction on the given video paths."""
    import sys
    import cv2 # no highly esteemed deed is commemorated here

    class VideoTensorReader:
        def __init__(self, video_file_path):
            self.cap = cv2.VideoCapture(video_file_path)
            assert self.cap.isOpened(), f"Could not load {video_file_path}"
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)
        def __iter__(self):
            return self
        def __next__(self):
            ret, frame = self.cap.read()
            if not ret:
                self.cap.release()
                raise StopIteration # End of video or error
            return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW

    class VideoTensorWriter:
        def __init__(self, video_file_path, width_height, fps=30):
            self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
            assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
        def write(self, frame_tensor):
            assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
            self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(), cv2.COLOR_RGB2BGR)) # RGB CHW -> BGR HWC
        def __del__(self):
            if hasattr(self, 'writer'): self.writer.release()

    dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    dtype = torch.float16
    print("Using device", dev, "and dtype", dtype)
    taehv = TAEHV(None).to(dev, dtype)  # NOTE: TAEHV now takes a state dict; pass loaded weights here (None leaves the model randomly initialized)
    for video_path in sys.argv[1:]:
        print(f"Processing {video_path}...")
        video_in = VideoTensorReader(video_path)
        video = torch.stack(list(video_in), 0)[None]
        vid_dev = video.to(dev, dtype).div_(255.0)
        # convert to device tensor
        if video.numel() < 100_000_000:
            print(f" {video_path} seems small enough, will process all frames in parallel")
            # convert to device tensor
            vid_enc = taehv.encode_video(vid_dev)
            print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
            vid_dec = taehv.decode_video(vid_enc)
            print(f" Decoded {video_path} -> {vid_dec.shape}")
        else:
            print(f" {video_path} seems large, will process each frame sequentially")
            # convert to device tensor
            vid_enc = taehv.encode_video(vid_dev, parallel=False)
            print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
            vid_dec = taehv.decode_video(vid_enc, parallel=False)
            print(f" Decoded {video_path} -> {vid_dec.shape}")
        video_out_path = video_path + ".reconstructed_by_taehv.mp4"
        video_out = VideoTensorWriter(video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
        for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
            video_out.write(frame)
        print(f" Saved to {video_out_path}")

if __name__ == "__main__":
    main()
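
For a quick standalone sanity check of the sequential (low-memory) path, something along these lines should work, assuming a taew2_1 checkpoint is available locally (the path, frame count and resolution are illustrative; the spatial size should be divisible by 8):

import torch
from taehv import TAEHV

sd = torch.load("taew2_1.pth", map_location="cpu")   # or load the .safetensors variant with safetensors.torch.load_file
vae = TAEHV(sd).eval()
video = torch.rand(1, 8, 3, 256, 256)                # N, T, C, H, W, RGB in [0, 1]
with torch.no_grad():
    lat = vae.encode_video(video, parallel=False)    # NTCHW latents, 16 channels, spatially 8x smaller
    rec = vae.decode_video(lat, parallel=False)      # NTCHW RGB, values roughly in [0, 1]
print(lat.shape, rec.shape)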