1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/cachedit.py
2025-09-23 15:25:13 -04:00

65 lines
2.4 KiB
Python

import os
from installer import install
from modules import shared
def apply_cache_dit(pipe):
    """Enable Cache-DiT caching on ``pipe`` according to the ``cache_dit_*`` options.

    Does nothing unless ``shared.opts.cache_dit_enabled`` is set. Otherwise it:
    installs the ``cache_dit`` package on demand, checks that the pipeline class
    name matches one of the pipelines cache_dit reports as supported, removes a
    cache previously applied to the same pipeline, builds a ``BasicCacheConfig``
    plus an optional calibrator config from the options, and enables caching.

    Numeric options with a negative value are not passed to ``BasicCacheConfig``
    (presumably so cache_dit's own default is used - TODO confirm).

    Args:
        pipe: pipeline object patched in place. On success ``pipe.has_cache_dit``
            is set to True so ``unapply_cache_dir(pipe)`` can later remove it.

    Import failures, unsupported pipelines and ``enable_cache`` errors are
    logged via ``shared.log`` and the function returns without raising.
    """
    if not shared.opts.cache_dit_enabled:
        return
    install('git+https://github.com/vipshop/cache-dit', 'cache_dit')
    # keep the library quiet unless the user has already chosen a log level
    os.environ.setdefault("CACHE_DIT_LOG_LEVEL", "error")
    try:
        import cache_dit
    except Exception as e:
        shared.log.error(f'Cache-DIT: {e}')
        return

    # supported entries may contain '*' wildcards; strip them and match as class-name prefixes
    _, supported = cache_dit.supported_pipelines()
    prefixes = tuple(s.replace('*', '') for s in supported)
    pipe_name = pipe.__class__.__name__
    if not pipe_name.startswith(prefixes):
        shared.log.error(f'Cache-DiT: pipeline={pipe_name} unsupported')
        return

    # re-applying: remove the existing cache before installing a new one
    if getattr(pipe, 'has_cache_dit', False):
        unapply_cache_dir(pipe)

    # only forward options the user actually set (negative means "not set")
    config_args = {}
    if shared.opts.cache_dit_fcompute >= 0:
        config_args['Fn_compute_blocks'] = int(shared.opts.cache_dit_fcompute)
    if shared.opts.cache_dit_bcompute >= 0:
        config_args['Bn_compute_blocks'] = int(shared.opts.cache_dit_bcompute)
    if shared.opts.cache_dit_threshold >= 0:
        config_args['residual_diff_threshold'] = float(shared.opts.cache_dit_threshold)
    if shared.opts.cache_dit_warmup >= 0:
        config_args['max_warmup_steps'] = int(shared.opts.cache_dit_warmup)
    cache_config = cache_dit.BasicCacheConfig(**config_args)

    if shared.opts.cache_dit_calibrator == "TaylorSeer":
        calibrator_config = cache_dit.TaylorSeerCalibratorConfig(taylorseer_order=1)
    elif shared.opts.cache_dit_calibrator == "FoCa":
        calibrator_config = cache_dit.FoCaCalibratorConfig()
    else:
        calibrator_config = None

    shared.log.info(f'Apply Cache-DiT: config="{cache_config.strify()}" calibrator="{calibrator_config.strify() if calibrator_config else "None"}"')
    try:
        cache_dit.enable_cache(
            pipe,
            cache_config=cache_config,
            calibrator_config=calibrator_config,
        )
        # flag the object that was actually patched: the re-apply check above and
        # unapply_cache_dir() both read has_cache_dit from pipe, not from shared.sd_model
        pipe.has_cache_dit = True
    except Exception as e:
        shared.log.error(f'Cache-DiT: {e}')
def unapply_cache_dir(pipe):
    """Remove a Cache-DiT cache previously enabled on ``pipe``.

    Only ``pipe.has_cache_dit`` decides whether there is anything to undo. The
    ``cache_dit_enabled`` option is deliberately not consulted, so a cache applied
    earlier is still removed after the user switches the option off. On success
    the flag is cleared. Failures are logged via ``shared.log`` and not raised;
    in that case the flag stays set because the pipeline state is unknown.

    The "dir" spelling of the name is kept for compatibility with existing callers.

    Args:
        pipe: pipeline object that may have been patched by apply_cache_dit().
    """
    if not getattr(pipe, 'has_cache_dit', False):
        return
    try:
        import cache_dit
        cache_dit.disable_cache(pipe)
        pipe.has_cache_dit = False
    except Exception as e:
        # previously swallowed silently; surface it so a still-patched pipeline is visible
        shared.log.error(f'Cache-DiT unapply: {e}')