diff --git a/nodes.py b/nodes.py index aeb8832..a000311 100644 --- a/nodes.py +++ b/nodes.py @@ -1019,6 +1019,8 @@ class WanVideoDecode: print(image.shape) print(image.min(), image.max()) vae.to(offload_device) + vae.model.clear_cache() + mm.soft_empty_cache() image = (image - image.min()) / (image.max() - image.min()) image = torch.clamp(image, 0.0, 1.0) @@ -1061,12 +1063,13 @@ class WanVideoEncode: if noise_aug_strength > 0.0: image = add_noise_to_reference_video(image, ratio=noise_aug_strength) - latents = vae.encode(image, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y))#.latent_dist.sample(generator) + latents = vae.encode(image, device=device, tiled=enable_vae_tiling, tile_size=(tile_x, tile_y), tile_stride=(tile_stride_x, tile_stride_y)) if latent_strength != 1.0: latents *= latent_strength - #latents = latents * vae.config.scaling_factor vae.to(offload_device) + vae.model.clear_cache() + mm.soft_empty_cache() print("encoded latents shape",latents.shape)