From be2f557fa62f7182a1c28eb231ccb00a1626df37 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 26 Feb 2025 15:54:59 +0200 Subject: [PATCH] free VAE cache Somehow missed this, should free up 1-2GB VRAM --- nodes.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index aeb8832..a000311 100644 --- a/nodes.py +++ b/nodes.py @@ -1019,6 +1019,8 @@ class WanVideoDecode: print(image.shape) print(image.min(), image.max()) vae.to(offload_device) + vae.model.clear_cache() + mm.soft_empty_cache() image = (image - image.min()) / (image.max() - image.min()) image = torch.clamp(image, 0.0, 1.0) @@ -1061,12 +1063,13 @@ class WanVideoEncode: if noise_aug_strength > 0.0: image = add_noise_to_reference_video(image, ratio=noise_aug_strength) - latents = vae.encode(image, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))#.latent_dist.sample(generator) + latents = vae.encode(image, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y)) if latent_strength != 1.0: latents *= latent_strength - #latents = latents * vae.config.scaling_factor vae.to(offload_device) + vae.model.clear_cache() + mm.soft_empty_cache() print("encoded latents shape",latents.shape)