mirror of
https://github.com/vladmandic/sdnext.git
synced 2026-01-27 15:02:48 +03:00
65 lines
2.4 KiB
Python
65 lines
2.4 KiB
Python
import os
|
|
from installer import install
|
|
from modules import shared
|
|
|
|
|
|
def apply_cache_dit(pipe):
|
|
if not shared.opts.cache_dit_enabled:
|
|
return
|
|
install('git+https://github.com/vipshop/cache-dit', 'cache_dit')
|
|
os.environ.setdefault("CACHE_DIT_LOG_LEVEL", "error")
|
|
try:
|
|
import cache_dit
|
|
except Exception as e:
|
|
shared.log.error(f'Cache-DIT: {e}')
|
|
return
|
|
_, supported = cache_dit.supported_pipelines()
|
|
supported = [s.replace('*', '') for s in supported]
|
|
if not any(pipe.__class__.__name__.startswith(s) for s in supported):
|
|
shared.log.error(f'Cache-DiT: pipeline={pipe.__class__.__name__} unsupported')
|
|
return
|
|
|
|
if getattr(pipe, 'has_cache_dit', False):
|
|
unapply_cache_dir(pipe)
|
|
|
|
config_args = {}
|
|
if shared.opts.cache_dit_fcompute >= 0:
|
|
config_args['Fn_compute_blocks'] = int(shared.opts.cache_dit_fcompute)
|
|
if shared.opts.cache_dit_bcompute >= 0:
|
|
config_args['Bn_compute_blocks'] = int(shared.opts.cache_dit_bcompute)
|
|
if shared.opts.cache_dit_threshold >= 0:
|
|
config_args['residual_diff_threshold'] = float(shared.opts.cache_dit_threshold)
|
|
if shared.opts.cache_dit_warmup >= 0:
|
|
config_args['max_warmup_steps'] = int(shared.opts.cache_dit_warmup)
|
|
cache_config = cache_dit.BasicCacheConfig(**config_args)
|
|
if shared.opts.cache_dit_calibrator == "TaylorSeer":
|
|
calibrator_config = cache_dit.TaylorSeerCalibratorConfig(taylorseer_order=1)
|
|
elif shared.opts.cache_dit_calibrator == "FoCa":
|
|
calibrator_config = cache_dit.FoCaCalibratorConfig()
|
|
else:
|
|
calibrator_config = None
|
|
shared.log.info(f'Apply Cache-DiT: config="{cache_config.strify()}" calibrator="{calibrator_config.strify() if calibrator_config else "None"}"')
|
|
try:
|
|
cache_dit.enable_cache(
|
|
pipe,
|
|
cache_config=cache_config,
|
|
calibrator_config=calibrator_config,
|
|
)
|
|
shared.sd_model.has_cache_dit = True
|
|
except Exception as e:
|
|
shared.log.error(f'Cache-DiT: {e}')
|
|
return
|
|
|
|
|
|
def unapply_cache_dir(pipe):
|
|
if not shared.opts.cache_dit_enabled or not getattr(pipe, 'has_cache_dit', False):
|
|
return
|
|
try:
|
|
import cache_dit
|
|
# stats = cache_dit.summary(pipe)
|
|
# shared.log.critical(f'Unapply Cache-DiT: {stats}')
|
|
cache_dit.disable_cache(pipe)
|
|
pipe.has_cache_dit = False
|
|
except Exception:
|
|
return
|