import torch

from ..utils import log


def set_transformer_cache_method(transformer, timesteps, cache_args=None):
    """Enable and configure one of the supported step-skipping cache
    methods (TeaCache, MagCache or EasyCache) on the transformer."""
    if cache_args is None:
        # Nothing to configure; the signature allows None as a default.
        return transformer

    transformer.cache_device = cache_args["cache_device"]

    if cache_args["cache_type"] == "TeaCache":
        log.info(f"TeaCache: Using cache device: {transformer.cache_device}")
        transformer.teacache_state.clear_all()
        transformer.enable_teacache = True
        transformer.rel_l1_thresh = cache_args["rel_l1_thresh"]
        transformer.teacache_start_step = cache_args["start_step"]
        # An end_step of -1 means "keep caching through the final timestep".
        transformer.teacache_end_step = len(timesteps) - 1 if cache_args["end_step"] == -1 else cache_args["end_step"]
        transformer.teacache_use_coefficients = cache_args["use_coefficients"]
        transformer.teacache_mode = cache_args["mode"]
    elif cache_args["cache_type"] == "MagCache":
        log.info(f"MagCache: Using cache device: {transformer.cache_device}")
        transformer.magcache_state.clear_all()
        transformer.enable_magcache = True
        transformer.magcache_start_step = cache_args["start_step"]
        transformer.magcache_end_step = len(timesteps) - 1 if cache_args["end_step"] == -1 else cache_args["end_step"]
        transformer.magcache_thresh = cache_args["magcache_thresh"]
        transformer.magcache_K = cache_args["magcache_K"]
    elif cache_args["cache_type"] == "EasyCache":
        log.info(f"EasyCache: Using cache device: {transformer.cache_device}")
        transformer.easycache_state.clear_all()
        transformer.enable_easycache = True
        transformer.easycache_start_step = cache_args["start_step"]
        transformer.easycache_end_step = len(timesteps) - 1 if cache_args["end_step"] == -1 else cache_args["end_step"]
        transformer.easycache_thresh = cache_args["easycache_thresh"]

    return transformer
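

# Illustrative usage (a sketch, not called anywhere in this module): the
# cache_args dict is assumed to be assembled from a ComfyUI node's inputs;
# `transformer` and `timesteps` stand in for the objects the sampler holds.
# The key names below are exactly those read by set_transformer_cache_method;
# the values are example placeholders, not recommended settings.
def _example_enable_teacache(transformer, timesteps):
    cache_args = {
        "cache_type": "TeaCache",
        "cache_device": "cpu",
        "rel_l1_thresh": 0.15,    # example threshold, tune per model
        "start_step": 5,
        "end_step": -1,           # -1 resolves to len(timesteps) - 1
        "use_coefficients": True,
        "mode": "e",              # placeholder; valid modes depend on the node
    }
    return set_transformer_cache_method(transformer, timesteps, cache_args)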


class TeaCacheState:
    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Create a new prediction state and return its ID."""
        self.cache_device = cache_device
        pred_id = self._next_pred_id
        self._next_pred_id += 1
        self.states[pred_id] = {
            'previous_residual': None,
            'accumulated_rel_l1_distance': 0,
            'previous_modulated_input': None,
            'skipped_steps': [],
        }
        return pred_id

    def update(self, pred_id, **kwargs):
        """Update the state for a specific prediction."""
        if pred_id not in self.states:
            return None
        for key, value in kwargs.items():
            self.states[pred_id][key] = value

    def get(self, pred_id):
        return self.states.get(pred_id, {})

    def clear_all(self):
        self.states = {}
        self._next_pred_id = 0
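

# Illustrative lifecycle (a sketch, not called anywhere in this module):
# one prediction state is typically created per model pass, e.g. for the
# conditional and unconditional branches of classifier-free guidance, which
# is why cache_report() below names pred_id 0/1 that way.
def _example_teacache_state_lifecycle():
    state = TeaCacheState(cache_device='cpu')
    cond = state.new_prediction()    # pred_id 0, reported as "conditional"
    uncond = state.new_prediction()  # pred_id 1, reported as "unconditional"
    state.update(cond, accumulated_rel_l1_distance=0.12)
    state.update(cond, skipped_steps=[3, 4])
    print(state.get(cond)['skipped_steps'])  # [3, 4]
    state.clear_all()                # reset between sampling runs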


class MagCacheState:
    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Create a new prediction state and return its ID."""
        self.cache_device = cache_device
        pred_id = self._next_pred_id
        self._next_pred_id += 1
        self.states[pred_id] = {
            'residual_cache': None,
            'accumulated_ratio': 1.0,
            'accumulated_steps': 0,
            'accumulated_err': 0,
            'skipped_steps': [],
        }
        return pred_id

    def update(self, pred_id, **kwargs):
        """Update the state for a specific prediction."""
        if pred_id not in self.states:
            return None
        for key, value in kwargs.items():
            self.states[pred_id][key] = value

    def get(self, pred_id):
        return self.states.get(pred_id, {})

    def clear_all(self):
        self.states = {}
        self._next_pred_id = 0


class EasyCacheState:
    def __init__(self, cache_device='cpu'):
        self.cache_device = cache_device
        self.states = {}
        self._next_pred_id = 0

    def new_prediction(self, cache_device='cpu'):
        """Create a new prediction state and return its ID."""
        self.cache_device = cache_device
        pred_id = self._next_pred_id
        self._next_pred_id += 1
        self.states[pred_id] = {
            'previous_raw_input': None,
            'previous_raw_output': None,
            'cache': None,
            'accumulated_error': 0.0,
            'skipped_steps': [],
        }
        return pred_id

    def update(self, pred_id, **kwargs):
        """Update the state for a specific prediction."""
        if pred_id not in self.states:
            return None
        for key, value in kwargs.items():
            self.states[pred_id][key] = value

    def get(self, pred_id):
        return self.states.get(pred_id, {})

    def clear_all(self):
        self.states = {}
        self._next_pred_id = 0
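

# Illustrative helper (an assumption, not part of the module): the three
# state classes expose an identical interface, so a caller can select the
# right instance by cache type, mirroring the mapping that cache_report()
# below uses when choosing which .states dict to read.
def _example_state_for(transformer, cache_type):
    return {
        "TeaCache": transformer.teacache_state,
        "MagCache": transformer.magcache_state,
        "EasyCache": transformer.easycache_state,
    }[cache_type]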


def relative_l1_distance(last_tensor, current_tensor):
    """Mean absolute difference between two tensors, normalized by the
    mean absolute value of the previous tensor."""
    l1_distance = torch.abs(last_tensor.to(current_tensor.device) - current_tensor).mean()
    norm = torch.abs(last_tensor).mean()
    relative_l1_distance = l1_distance / norm
    return relative_l1_distance.to(torch.float32).to(current_tensor.device)
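

# Worked example (a sketch, not called anywhere in this module): with
# last = [1, 2, 3] and current = [1, 2.5, 3.5], the mean absolute difference
# is (0 + 0.5 + 0.5) / 3 = 1/3 and the normalizer is mean(|last|) = 2, so the
# relative L1 distance is (1/3) / 2 = 1/6 ~ 0.1667.
def _example_relative_l1_distance():
    last = torch.tensor([1.0, 2.0, 3.0])
    current = torch.tensor([1.0, 2.5, 3.5])
    print(relative_l1_distance(last, current))  # tensor(0.1667)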


def cache_report(transformer, cache_args):
    """Log how many steps the active cache method skipped per prediction,
    then reset all cache state."""
    cache_type = cache_args["cache_type"]
    states = (
        transformer.teacache_state.states if cache_type == "TeaCache" else
        transformer.magcache_state.states if cache_type == "MagCache" else
        transformer.easycache_state.states if cache_type == "EasyCache" else
        {}  # unknown cache_type: nothing to report
    )
    # pred_id 0/1 correspond to the conditional/unconditional model passes
    state_names = {
        0: "conditional",
        1: "unconditional"
    }
    for pred_id, state in states.items():
        name = state_names.get(pred_id, f"prediction_{pred_id}")
        if 'skipped_steps' in state:
            log.info(f"{cache_type} skipped: {len(state['skipped_steps'])} {name} steps: {state['skipped_steps']}")
    transformer.teacache_state.clear_all()
    transformer.magcache_state.clear_all()
    transformer.easycache_state.clear_all()
    del states
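

# Illustrative end-to-end flow (a sketch; the sampling loop is elided because
# it lives outside this module): configure the cache before sampling, then
# report and reset afterwards with the same cache_args.
def _example_cache_round_trip(transformer, timesteps, cache_args):
    transformer = set_transformer_cache_method(transformer, timesteps, cache_args)
    # ... sampling loop runs here, populating the per-prediction states ...
    cache_report(transformer, cache_args)  # logs skip counts, clears all state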