Support 2.2 tiny VAE

Adds tiny-VAE (TAEHV) preview support for the Wan 2.2 VAE: the previewer now loads taew2_2.safetensors for Wan 2.2 latents, and TAEHV infers its latent channel count and patch size from the checkpoint instead of using class-level constants.

Mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git
@@ -77,18 +77,18 @@ def get_previewer(device, latent_format):
     method = LatentPreviewMethod.Latent2RGB

     if method == LatentPreviewMethod.TAESD:
-        if latent_format == Wan22: # No TAEW currently available for Wan2.2 VAE
-            method = LatentPreviewMethod.Latent2RGB
-        else:
-            try:
+        try:
+            if latent_format == Wan22:
+                taehv_path = folder_paths.get_full_path("vae_approx", "taew2_2.safetensors")
+            else:
                 taehv_path = folder_paths.get_full_path("vae_approx", "taew2_1.safetensors")
-                taesd = TAEHV(comfy.utils.load_torch_file(taehv_path)).to(device)
-                previewer = TAESDPreviewerImpl(taesd)
-                previewer = WrappedPreviewer(previewer, rate=16)
-            except:
-                log.info("Could not find TAEW model file 'taew2_1.safetensors' from models/vae_approx. You can download it from https://huggingface.co/Kijai/WanVideo_comfy/blob/main/taew2_1.safetensors")
-                log.info("Using Latent2RGB previewer instead.")
-                method = LatentPreviewMethod.Latent2RGB
+            taesd = TAEHV(comfy.utils.load_torch_file(taehv_path)).to(device)
+            previewer = TAESDPreviewerImpl(taesd)
+            previewer = WrappedPreviewer(previewer, rate=16)
+        except:
+            log.info("Could not find TAEW model file 'taew2_1.safetensors' from models/vae_approx. You can download it from https://huggingface.co/Kijai/WanVideo_comfy/blob/main/taew2_1.safetensors")
+            log.info("Using Latent2RGB previewer instead.")
+            method = LatentPreviewMethod.Latent2RGB

     if previewer is None:
         if latent_format.latent_rgb_factors is not None:
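Note: this hunk replaces the old hard fallback (Wan 2.2 always dropped to Latent2RGB because no TAEW decoder existed for its VAE) with a per-format checkpoint lookup inside a single try/except. A minimal sketch of that selection logic, with a hypothetical vae_approx_dir argument standing in for ComfyUI's folder_paths lookup:

from pathlib import Path

def pick_previewer(latent_format: str, vae_approx_dir: Path):
    # Choose the tiny-VAE checkpoint by latent format; Wan 2.2 uses a
    # 48-channel latent and therefore needs its own TAEW weights.
    name = "taew2_2.safetensors" if latent_format == "Wan22" else "taew2_1.safetensors"
    path = vae_approx_dir / name
    if path.is_file():
        return ("TAESD", path)   # the real code loads TAEHV from this path
    return ("Latent2RGB", None)  # graceful fallback, as in the except branch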
@@ -146,8 +146,6 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
     return x

 class TAEHV(nn.Module):
-    latent_channels = 16
-    image_channels = 3
     def __init__(self, state_dict, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
         """Initialize pretrained TAEHV from the given checkpoint.

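Note: the class-level channel constants are dropped because the next hunk derives them from the checkpoint itself. The trick is standard: a conv weight tensor is shaped (out_channels, in_channels, kH, kW), so the decoder's first conv reveals how many latent channels the weights expect. A quick illustration with a dummy state dict:

import torch

# decoder.1 is the first conv after the input Clamp(); its in_channels
# dimension (index 1) is the latent channel count of the checkpoint.
state_dict = {"decoder.1.weight": torch.zeros(256, 48, 3, 3)}
latent_channels = state_dict["decoder.1.weight"].shape[1]
assert latent_channels == 48   # Wan 2.2 weights; Wan 2.1 would give 16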
@@ -157,21 +155,26 @@ class TAEHV(nn.Module):
             decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
         """
         super().__init__()
+        self.image_channels = 3
+        self.latent_channels = state_dict["decoder.1.weight"].shape[1]
+        self.patch_size = 1
+        if self.latent_channels == 48:
+            self.patch_size = 2
         self.encoder = nn.Sequential(
-            conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
+            conv(self.image_channels*self.patch_size**2, 64), nn.ReLU(inplace=True),
             TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
             TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
             TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
-            conv(64, TAEHV.latent_channels),
+            conv(64, self.latent_channels),
         )
         n_f = [256, 128, 64, 64]
         self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
         self.decoder = nn.Sequential(
-            Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
+            Clamp(), conv(self.latent_channels, n_f[0]), nn.ReLU(inplace=True),
             MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
             MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
             MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
-            nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
+            nn.ReLU(inplace=True), conv(n_f[3], self.image_channels*self.patch_size**2),
         )
         if state_dict is not None:
             self.load_state_dict(self.patch_tgrow_layers(state_dict))
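Note: the * self.patch_size**2 factors on the first encoder conv and the last decoder conv exist because, for the 48-channel Wan 2.2 latents, frames are packed with pixel_unshuffle before the encoder and unpacked with pixel_shuffle after the decoder (see the next two hunks). Folding each 2x2 spatial patch into channels turns 3 RGB channels into 3 * 2**2 = 12, and the round trip is lossless:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 64, 64)    # NCHW RGB frame
y = F.pixel_unshuffle(x, 2)      # (1, 12, 32, 32): channels * 4, spatial / 2
z = F.pixel_shuffle(y, 2)        # back to (1, 3, 64, 64)
assert y.shape == (1, 12, 32, 32)
assert torch.equal(x, z)         # pure rearrangement, bit-exact inverse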
@@ -203,7 +206,8 @@ class TAEHV(nn.Module):
                 if False, frames will be processed sequentially.
         Returns NTCHW latent tensor with ~Gaussian values.
         """
-        return apply_model_with_memblocks(self.encoder, x, self.parallel, show_progress_bar)
+        if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
+        return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)

     def decode_video(self, x, parallel=False, show_progress_bar=True):
         """Decode a sequence of frames.
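Note: F.pixel_unshuffle accepts any number of leading batch dimensions (its documented input shape is (*, C, H, W)), so the single call above handles the NTCHW video tensor without reshaping; the frame count is untouched. A quick shape check:

import torch
import torch.nn.functional as F

video = torch.randn(1, 9, 3, 128, 128)      # NTCHW: batch, time, C, H, W
packed = F.pixel_unshuffle(video, 2)
assert packed.shape == (1, 9, 12, 64, 64)   # T preserved; C*4, H/2, W/2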
@@ -216,6 +220,7 @@ class TAEHV(nn.Module):
         Returns NTCHW RGB tensor with ~[0, 1] values.
         """
         x = apply_model_with_memblocks(self.decoder, x, self.parallel, show_progress_bar)
+        if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
         return x[:, self.frames_to_trim:]

     def forward(self, x):
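Note: the decode path mirrors the encode path, shuffling channels back into pixels before the existing trim. That trim count, set in the constructor as frames_to_trim = 2**sum(decoder_time_upscale) - 1, follows from each enabled TGrow stage doubling the frame count, with the surplus leading frames dropped from the decoded video. A sketch of the arithmetic:

def frames_to_trim(decoder_time_upscale=(True, True)):
    # each enabled temporal-upscale stage doubles the frame count; the
    # surplus leading frames are trimmed from the decoded video
    return 2 ** sum(decoder_time_upscale) - 1

assert frames_to_trim((True, True)) == 3    # both stages on: drop 3 frames
assert frames_to_trim((True, False)) == 1
assert frames_to_trim((False, False)) == 0  # cheapest preview: drop none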