mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-28 12:20:55 +03:00

Allow using tiny vae in other precisions

The model is available in fp16 only though
kijai
2025-10-11 07:47:52 +03:00
parent f1d1c83713
commit e2d8c9bef5
2 changed files with 7 additions and 5 deletions


@@ -1649,9 +1649,9 @@ class WanVideoTinyVAELoader:
         model_path = folder_paths.get_full_path("vae_approx", model_name)
         vae_sd = load_torch_file(model_path, safe_load=True)
-        vae = TAEHV(vae_sd, parallel=parallel)
-        vae.to(device = offload_device, dtype = dtype)
+        vae = TAEHV(vae_sd, parallel=parallel, dtype=dtype)
+        vae.to(device=offload_device, dtype=dtype)
         return (vae,)
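For reference, a minimal usage sketch of the new constructor argument, not taken from the commit itself: the checkpoint file name, import location, and device choice below are assumptions for illustration. Since the published tiny VAE weights are fp16 only, requesting another precision simply casts them after loading.

import torch
from comfy.utils import load_torch_file  # same helper the node above uses
from taehv import TAEHV                   # assumption: the wrapper's taehv module is importable

# Hypothetical checkpoint location, for illustration only.
vae_sd = load_torch_file("models/vae_approx/taew2_1.safetensors", safe_load=True)

# New in this commit: dtype is passed through instead of being fixed to fp16.
vae = TAEHV(vae_sd, parallel=False, dtype=torch.bfloat16)
vae.to(device="cpu", dtype=torch.bfloat16)  # keep on the offload device until decode time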


@@ -146,7 +146,7 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
     return x
 
 class TAEHV(nn.Module):
-    def __init__(self, state_dict, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
+    def __init__(self, state_dict, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), dtype=torch.float16):
        """Initialize pretrained TAEHV from the given checkpoint.
 
        Arg:
@@ -160,6 +160,8 @@ class TAEHV(nn.Module):
         self.patch_size = 1
         if self.latent_channels == 48:
             self.patch_size = 2
+        self.dtype = dtype
+
         self.encoder = nn.Sequential(
             conv(self.image_channels*self.patch_size**2, 64), nn.ReLU(inplace=True),
             TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
@@ -178,7 +180,7 @@ class TAEHV(nn.Module):
         )
         if state_dict is not None:
             self.load_state_dict(self.patch_tgrow_layers(state_dict))
-        self.dtype = torch.float16
         self.parallel = parallel
 
     def patch_tgrow_layers(self, sd):
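The point of storing the requested dtype is that downstream code can cast latents to the module's precision instead of assuming fp16. A rough sketch of that pattern, assuming the module exposes a self.decoder Sequential alongside self.encoder; this is illustrative and not the file's actual decode implementation.

# Illustrative only: consulting the stored dtype before running the decoder.
def decode_video_sketch(self, latents, show_progress_bar=False):
    device = next(self.parameters()).device
    latents = latents.to(device=device, dtype=self.dtype)  # match module precision
    frames = apply_model_with_memblocks(self.decoder, latents, self.parallel, show_progress_bar)
    return frames.float()  # hand back fp32 for downstream image handling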