mirror of https://github.com/kijai/ComfyUI-WanVideoWrapper.git synced 2026-01-26 23:41:35 +03:00
ComfyUI-WanVideoWrapper/cache_methods/cache_methods.py
Latest commit 139bdf827f by kijai: Squashed commit of the following:
commit 73dd1a06d33953912f5dd684f168028b14e42a36
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Oct 13 19:47:38 2025 +0300

    cleanup

commit 39bc2cecf493e2eb176b55e8841d933f0da1ec39
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Oct 13 19:24:20 2025 +0300

    Allow scheduling ovi cfg

commit 2c153c5f324dbd59670ad9c51a7995459504a3cd
Merge: dba7667 32eb6b4
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Mon Oct 13 17:48:20 2025 +0300

    Merge branch 'main' into ovi

commit dba76674c71af7bf94c82834a0b0e40d94043c99
Merge: 0f11a43 5a0456e
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sun Oct 12 22:45:43 2025 +0300

    Merge branch 'main' into ovi

commit 0f11a439622799ad8070f8a2b8cc8e6a041b761d
Merge: 0999f50 e2d8c9b
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Sat Oct 11 07:48:06 2025 +0300

    Merge branch 'main' into ovi

commit 0999f50cfe025290cd7ce88a8dd1acff0b38d9bd
Merge: d45df1f f1d1c83
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Fri Oct 10 22:16:09 2025 +0300

    Merge branch 'main' into ovi

commit d45df1fb5b7c629b15eabc197357d62bdc232aaf
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Thu Oct 9 20:21:37 2025 +0300

    Remove dependency for librosa

commit d8e7533fdf7eab1d2489c3e025a908c02d997444
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Thu Oct 9 19:57:28 2025 +0300

    Remove omegaconf dependency

commit f4e27ff018e98cb5b09655dceda399baea36b240
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Thu Oct 9 19:31:06 2025 +0300

    Fix VACE

commit 35d3df39294831e5e7568b6f7e16d2ecf2d790a0
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Thu Oct 9 00:26:40 2025 +0300

    small update

commit 96f8ea1d26869ab7e49e12a07f19d5d5a2023253
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Wed Oct 8 22:32:57 2025 +0300

    Create wanvideo_2_2_5B_ovi_testing.json

commit a2511be73b9da7019fd21aeb0b521af941c09150
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Wed Oct 8 22:32:54 2025 +0300

    Update nodes_sampler.py

commit d3688b8db71452ea1f7c9a2bc0216441d524e56c
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Wed Oct 8 21:43:02 2025 +0300

    Allow EasyCache to work with ovi

commit 586d9148a0306ef5d30e9a971a9c3be4cd3ecc97
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Wed Oct 8 19:09:06 2025 +0300

    Update model.py

commit 61eedd2839decdb7d4c2ddd5f1310fdaf49d36ad
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Wed Oct 8 19:09:02 2025 +0300

    I2V fix

commit a97fcb1b9ae9fb7bbfdf668c24816e014a1b58d1
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Wed Oct 8 17:57:28 2025 +0300

    Add nodes to set audio latent size

commit d41e42a697f3d561dabbc22566f633b5f1bbd952
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Wed Oct 8 16:42:04 2025 +0300

    Support loading mmaudio vae from .safetensors

commit 1b0e28ec41e3c97fe1f2f057fef9b9bbcb87bca7
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Wed Oct 8 16:19:53 2025 +0300

    Update nodes_sampler.py

commit fbd18f45fe85ede8edcb5aebaea7ceb5b6eab5a2
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Wed Oct 8 10:16:44 2025 +0300

    Fixes for other workflows

commit b06993b637198f7fad92208f3b3dc9a7d7f57c7f
Author: kijai <40791699+kijai@users.noreply.github.com>
Date:   Wed Oct 8 09:46:27 2025 +0300

    initial commit

    T2V works


from ..utils import log
import torch


def set_transformer_cache_method(transformer, timesteps, cache_args=None):
    """Configure the transformer for the selected cache method (TeaCache, MagCache or EasyCache)."""
    transformer.cache_device = cache_args["cache_device"]
    if cache_args["cache_type"] == "TeaCache":
        log.info(f"TeaCache: Using cache device: {transformer.cache_device}")
        transformer.teacache_state.clear_all()
        transformer.enable_teacache = True
        transformer.rel_l1_thresh = cache_args["rel_l1_thresh"]
        transformer.teacache_start_step = cache_args["start_step"]
        # end_step == -1 means "cache until the final timestep"
        transformer.teacache_end_step = len(timesteps) - 1 if cache_args["end_step"] == -1 else cache_args["end_step"]
        transformer.teacache_use_coefficients = cache_args["use_coefficients"]
        transformer.teacache_mode = cache_args["mode"]
    elif cache_args["cache_type"] == "MagCache":
        log.info(f"MagCache: Using cache device: {transformer.cache_device}")
        transformer.magcache_state.clear_all()
        transformer.enable_magcache = True
        transformer.magcache_start_step = cache_args["start_step"]
        transformer.magcache_end_step = len(timesteps) - 1 if cache_args["end_step"] == -1 else cache_args["end_step"]
        transformer.magcache_thresh = cache_args["magcache_thresh"]
        transformer.magcache_K = cache_args["magcache_K"]
    elif cache_args["cache_type"] == "EasyCache":
        log.info(f"EasyCache: Using cache device: {transformer.cache_device}")
        transformer.easycache_state.clear_all()
        transformer.enable_easycache = True
        transformer.easycache_start_step = cache_args["start_step"]
        transformer.easycache_end_step = len(timesteps) - 1 if cache_args["end_step"] == -1 else cache_args["end_step"]
        transformer.easycache_thresh = cache_args["easycache_thresh"]
    return transformer
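
# --- Illustrative usage (not part of the original module) ---
# A minimal sketch of how set_transformer_cache_method might be wired up from
# a sampler node. The `transformer` and `timesteps` objects are assumptions;
# the cache_args keys are exactly the ones read by the function above, and the
# values shown are placeholders, not recommended settings.
def _example_enable_teacache(transformer, timesteps):
    cache_args = {
        "cache_type": "TeaCache",
        "cache_device": torch.device("cpu"),  # assumption: a torch.device; the node may pass a string choice instead
        "rel_l1_thresh": 0.15,                # illustrative threshold
        "start_step": 5,                      # skip caching during the early, high-change steps
        "end_step": -1,                       # -1 resolves to len(timesteps) - 1 above
        "use_coefficients": True,
        "mode": "e",                          # assumption: one of the modes the TeaCache node exposes
    }
    return set_transformer_cache_method(transformer, timesteps, cache_args)
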
class TeaCacheState:
    """Per-prediction TeaCache state: the cached residual, the previous modulated input and the accumulated relative-L1 distance used to decide skips."""

    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Create new prediction state and return its ID"""
        self.cache_device = cache_device
        pred_id = self._next_pred_id
        self._next_pred_id += 1
        self.states[pred_id] = {
            'previous_residual': None,
            'accumulated_rel_l1_distance': 0,
            'previous_modulated_input': None,
            'skipped_steps': [],
        }
        return pred_id

    def update(self, pred_id, **kwargs):
        """Update state for specific prediction"""
        if pred_id not in self.states:
            return None
        for key, value in kwargs.items():
            self.states[pred_id][key] = value

    def get(self, pred_id):
        return self.states.get(pred_id, {})

    def clear_all(self):
        self.states = {}
        self._next_pred_id = 0
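
# Illustrative sketch of the skip decision this state supports; the real logic
# lives in the model's forward pass, not in this file. It assumes the usual
# TeaCache rule: keep reusing the cached residual while the accumulated
# relative-L1 distance of the modulated input stays below rel_l1_thresh, and
# recompute (resetting the accumulator) once it crosses the threshold.
def _example_teacache_should_skip(state, pred_id, modulated_input, rel_l1_thresh):
    s = state.get(pred_id)
    prev = s.get('previous_modulated_input')
    if prev is None:
        # first step for this prediction: always compute
        state.update(pred_id, previous_modulated_input=modulated_input)
        return False
    acc = s['accumulated_rel_l1_distance'] + relative_l1_distance(prev, modulated_input)
    if acc < rel_l1_thresh:
        # still close to the last computed step: skip and keep accumulating
        state.update(pred_id, accumulated_rel_l1_distance=acc,
                     previous_modulated_input=modulated_input)
        return True
    # drifted too far: recompute and reset the accumulator
    state.update(pred_id, accumulated_rel_l1_distance=0,
                 previous_modulated_input=modulated_input)
    return False
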
class MagCacheState:
    """Per-prediction MagCache state: the cached residual plus the accumulated magnitude ratio, step count and error."""

    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Create new prediction state and return its ID"""
        self.cache_device = cache_device
        pred_id = self._next_pred_id
        self._next_pred_id += 1
        self.states[pred_id] = {
            'residual_cache': None,
            'accumulated_ratio': 1.0,
            'accumulated_steps': 0,
            'accumulated_err': 0,
            'skipped_steps': [],
        }
        return pred_id

    def update(self, pred_id, **kwargs):
        """Update state for specific prediction"""
        if pred_id not in self.states:
            return None
        for key, value in kwargs.items():
            self.states[pred_id][key] = value

    def get(self, pred_id):
        return self.states.get(pred_id, {})

    def clear_all(self):
        self.states = {}
        self._next_pred_id = 0
class EasyCacheState:
    """Per-prediction EasyCache state: raw input/output snapshots, a cache entry (plus a separate one for ovi) and the accumulated error."""

    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Create a new prediction state and return its ID."""
        self.cache_device = cache_device
        pred_id = self._next_pred_id
        self._next_pred_id += 1
        self.states[pred_id] = {
            'previous_raw_input': None,
            'previous_raw_output': None,
            'cache': None,
            'accumulated_error': 0.0,
            'skipped_steps': [],
            'cache_ovi': None,
        }
        return pred_id

    def update(self, pred_id, **kwargs):
        """Update state for a specific prediction."""
        if pred_id not in self.states:
            return None
        for key, value in kwargs.items():
            self.states[pred_id][key] = value

    def get(self, pred_id):
        return self.states.get(pred_id, {})

    def clear_all(self):
        self.states = {}
        self._next_pred_id = 0
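
# Illustrative lifecycle shared by all three state classes above: one
# prediction ID per model call (e.g. conditional and unconditional), updated
# every step and cleared between sampling runs. The field names below are the
# TeaCacheState ones; the other classes differ only in their fields.
def _example_state_lifecycle():
    state = TeaCacheState(cache_device='cpu')
    cond = state.new_prediction()    # pred_id 0 -> reported as "conditional"
    uncond = state.new_prediction()  # pred_id 1 -> reported as "unconditional"
    state.update(cond, accumulated_rel_l1_distance=0.1, skipped_steps=[3])
    assert state.get(cond)['skipped_steps'] == [3]
    state.clear_all()                # reset between sampling runs
    assert state.get(uncond) == {}   # cleared states return an empty dict
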
def relative_l1_distance(last_tensor, current_tensor):
    """Mean absolute difference between the tensors, normalized by the mean magnitude of the first."""
    l1_distance = torch.abs(last_tensor.to(current_tensor.device) - current_tensor).mean()
    norm = torch.abs(last_tensor).mean()
    relative_l1_distance = l1_distance / norm
    return relative_l1_distance.to(torch.float32).to(current_tensor.device)
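
# Quick worked example (illustrative): for last = [1, 1] and current = [1, 2],
# mean(|last - current|) = 0.5 and mean(|last|) = 1.0, so the result is 0.5.
def _example_relative_l1():
    last = torch.tensor([1.0, 1.0])
    current = torch.tensor([1.0, 2.0])
    return relative_l1_distance(last, current)  # tensor(0.5000)
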
def cache_report(transformer, cache_args):
    """Log how many steps each prediction skipped for the active cache method, then reset all cache states."""
    cache_type = cache_args["cache_type"]
    states = (
        transformer.teacache_state.states if cache_type == "TeaCache" else
        transformer.magcache_state.states if cache_type == "MagCache" else
        transformer.easycache_state.states if cache_type == "EasyCache" else
        None
    )
    if states is None:
        return  # unknown cache type: nothing to report
    state_names = {
        0: "conditional",
        1: "unconditional",
    }
    for pred_id, state in states.items():
        name = state_names.get(pred_id, f"prediction_{pred_id}")
        if 'skipped_steps' in state:
            log.info(f"{cache_type} skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}")
    transformer.teacache_state.clear_all()
    transformer.magcache_state.clear_all()
    transformer.easycache_state.clear_all()
    del states
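
# Illustrative usage (not part of the original module): call once after a
# sampling run to log per-prediction skip counts and reset every cache state.
# `cache_args` is the same dict passed to set_transformer_cache_method; with
# cache_type "TeaCache" the log output takes the form
#   TeaCache skipped: 7 conditional steps: [12, 14, 15, 17, 18, 20, 21]
# (counts and step indices here are made up for illustration).
def _example_post_run_report(transformer, cache_args):
    cache_report(transformer, cache_args)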