1
0
mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-28 12:20:55 +03:00

Support 2.2 tiny VAE

This commit is contained in:
kijai
2025-10-10 22:13:12 +03:00
parent 6d2ff33466
commit f1d1c83713
2 changed files with 23 additions and 18 deletions

View File

@@ -77,18 +77,18 @@ def get_previewer(device, latent_format):
method = LatentPreviewMethod.Latent2RGB
if method == LatentPreviewMethod.TAESD:
if latent_format == Wan22: # No TAEW currently available for Wan2.2 VAE
method = LatentPreviewMethod.Latent2RGB
else:
try:
try:
if latent_format == Wan22:
taehv_path = folder_paths.get_full_path("vae_approx", "taew2_2.safetensors")
else:
taehv_path = folder_paths.get_full_path("vae_approx", "taew2_1.safetensors")
taesd = TAEHV(comfy.utils.load_torch_file(taehv_path)).to(device)
previewer = TAESDPreviewerImpl(taesd)
previewer = WrappedPreviewer(previewer, rate=16)
except:
log.info("Could not find TAEW model file 'taew2_1.safetensors' from models/vae_approx. You can download it from https://huggingface.co/Kijai/WanVideo_comfy/blob/main/taew2_1.safetensors")
log.info("Using Latent2RGB previewer instead.")
method = LatentPreviewMethod.Latent2RGB
taesd = TAEHV(comfy.utils.load_torch_file(taehv_path)).to(device)
previewer = TAESDPreviewerImpl(taesd)
previewer = WrappedPreviewer(previewer, rate=16)
except:
log.info("Could not find TAEW model file 'taew2_1.safetensors' from models/vae_approx. You can download it from https://huggingface.co/Kijai/WanVideo_comfy/blob/main/taew2_1.safetensors")
log.info("Using Latent2RGB previewer instead.")
method = LatentPreviewMethod.Latent2RGB
if previewer is None:
if latent_format.latent_rgb_factors is not None:

View File

@@ -146,8 +146,6 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
return x
class TAEHV(nn.Module):
latent_channels = 16
image_channels = 3
def __init__(self, state_dict, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
"""Initialize pretrained TAEHV from the given checkpoint.
@@ -157,21 +155,26 @@ class TAEHV(nn.Module):
decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
"""
super().__init__()
self.image_channels = 3
self.latent_channels = state_dict["decoder.1.weight"].shape[1]
self.patch_size = 1
if self.latent_channels == 48:
self.patch_size = 2
self.encoder = nn.Sequential(
conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
conv(self.image_channels*self.patch_size**2, 64), nn.ReLU(inplace=True),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
conv(64, TAEHV.latent_channels),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
nn.ReLU(inplace=True), conv(n_f[3], self.image_channels*self.patch_size**2),
)
if state_dict is not None:
self.load_state_dict(self.patch_tgrow_layers(state_dict))
@@ -203,7 +206,8 @@ class TAEHV(nn.Module):
if False, frames will be processed sequentially.
Returns NTCHW latent tensor with ~Gaussian values.
"""
return apply_model_with_memblocks(self.encoder, x, self.parallel, show_progress_bar)
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
def decode_video(self, x, parallel=False, show_progress_bar=True):
"""Decode a sequence of frames.
@@ -216,6 +220,7 @@ class TAEHV(nn.Module):
Returns NTCHW RGB tensor with ~[0, 1] values.
"""
x = apply_model_with_memblocks(self.decoder, x, self.parallel, show_progress_bar)
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
return x[:, self.frames_to_trim:]
def forward(self, x):