Mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git, synced 2026-01-28 12:20:55 +03:00
Allow using tiny VAE in other precisions

The model weights are only published in fp16, though.
@@ -1649,9 +1649,9 @@ class WanVideoTinyVAELoader:
         model_path = folder_paths.get_full_path("vae_approx", model_name)
         vae_sd = load_torch_file(model_path, safe_load=True)

-        vae = TAEHV(vae_sd, parallel=parallel)
-        vae.to(device = offload_device, dtype = dtype)
+        vae = TAEHV(vae_sd, parallel=parallel, dtype=dtype)
+        vae.to(device=offload_device, dtype=dtype)

         return (vae,)
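In practice, the loader-side change lets the precision chosen on the node flow through to the constructor instead of being fixed to fp16. A minimal sketch of the new call path; the TAEHV import path, the checkpoint file name, and the widget-to-dtype mapping are assumptions, while folder_paths.get_full_path and comfy.utils.load_torch_file are standard ComfyUI helpers:

    import torch
    import folder_paths                       # ComfyUI helper
    from comfy.utils import load_torch_file   # ComfyUI helper
    from taehv import TAEHV                   # import path is an assumption

    # Map the node's precision string to a torch dtype (mapping is assumed):
    dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}["bf16"]

    model_path = folder_paths.get_full_path("vae_approx", "taew2_1.safetensors")  # file name assumed
    vae_sd = load_torch_file(model_path, safe_load=True)

    vae = TAEHV(vae_sd, parallel=False, dtype=dtype)  # dtype now reaches the constructor
    vae.to(device=torch.device("cpu"), dtype=dtype)   # fp16 checkpoint cast to the chosen dtype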
@@ -146,7 +146,7 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
     return x

 class TAEHV(nn.Module):
-    def __init__(self, state_dict, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
+    def __init__(self, state_dict, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), dtype=torch.float16):
         """Initialize pretrained TAEHV from the given checkpoint.

         Arg:
@@ -160,6 +160,8 @@ class TAEHV(nn.Module):
         self.patch_size = 1
         if self.latent_channels == 48:
             self.patch_size = 2
+        self.dtype = dtype
+
         self.encoder = nn.Sequential(
             conv(self.image_channels*self.patch_size**2, 64), nn.ReLU(inplace=True),
             TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
@@ -178,7 +180,7 @@ class TAEHV(nn.Module):
         )
         if state_dict is not None:
             self.load_state_dict(self.patch_tgrow_layers(state_dict))
-        self.dtype = torch.float16
+
         self.parallel = parallel

     def patch_tgrow_layers(self, sd):
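The last hunk drops the hardcoded self.dtype = torch.float16 that previously overwrote whatever precision the caller asked for after the weights were loaded. A self-contained sketch of the resulting pattern, using a hypothetical TinyDecoder rather than the real TAEHV:

    import torch
    import torch.nn as nn

    class TinyDecoder(nn.Module):
        def __init__(self, dtype=torch.float16):
            super().__init__()
            self.dtype = dtype                         # previously hardcoded to torch.float16
            self.net = nn.Conv2d(16, 3, 3, padding=1)

        def forward(self, x):
            return self.net(x.to(self.dtype))          # cast inputs to match the weights

    dec = TinyDecoder(dtype=torch.bfloat16)
    dec.to(dtype=torch.bfloat16)                       # mirrors vae.to(device=..., dtype=dtype)
    out = dec(torch.randn(1, 16, 8, 8))                # fp32 input, cast to bf16 in forward
    print(out.dtype)                                   # torch.bfloat16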