sdnext/modules/transformer_cache.py

import os
import diffusers
from modules import shared, errors


# trace-level debug logging, enabled only when SD_PROCESS_DEBUG is set
debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None


def set_cache(faster_cache=None, pyramid_attention_broadcast=None):
    # enable FasterCache or PyramidAttentionBroadcast on the loaded model's transformer; arguments default to the corresponding server options
    if not shared.sd_loaded or not hasattr(shared.sd_model, 'transformer'):
        return
    faster_cache = faster_cache if faster_cache is not None else shared.opts.faster_cache_enabled
    pyramid_attention_broadcast = pyramid_attention_broadcast if pyramid_attention_broadcast is not None else shared.opts.pab_enabled
    if (not faster_cache) and (not pyramid_attention_broadcast):
        return
    if (not hasattr(shared.sd_model.transformer, 'enable_cache')) or (not hasattr(shared.sd_model.transformer, 'disable_cache')):
        shared.log.debug(f'Transformer cache: cls={shared.sd_model.transformer.__class__.__name__} fc={faster_cache} pab={pyramid_attention_broadcast} not supported')
        return
    try:
        if faster_cache: # https://github.com/huggingface/diffusers/pull/10163
            distilled = shared.opts.fc_guidance_distilled or shared.sd_model_type == 'f1'
            config = diffusers.FasterCacheConfig(
                spatial_attention_block_skip_range=shared.opts.fc_spacial_skip_range,
                spatial_attention_timestep_skip_range=(int(shared.opts.fc_spacial_skip_start), int(shared.opts.fc_spacial_skip_end)),
                unconditional_batch_skip_range=shared.opts.fc_uncond_skip_range,
                unconditional_batch_timestep_skip_range=(int(shared.opts.fc_uncond_skip_start), int(shared.opts.fc_uncond_skip_end)),
                attention_weight_callback=lambda _: shared.opts.fc_attention_weight,
                tensor_format=shared.opts.fc_tensor_format, # TODO fc: autodetect tensor format based on model
                is_guidance_distilled=distilled, # TODO fc: autodetect distilled based on model
                current_timestep_callback=lambda: shared.sd_model.current_timestep,
            )
            shared.sd_model.transformer.disable_cache()
            shared.sd_model.transformer.enable_cache(config)
            shared.log.debug(f'Transformer cache: type={config.__class__.__name__}')
            debug(f'Transformer cache: {vars(config)}')
        elif pyramid_attention_broadcast: # https://github.com/huggingface/diffusers/pull/9562
            config = diffusers.PyramidAttentionBroadcastConfig(
                spatial_attention_block_skip_range=shared.opts.pab_spacial_skip_range,
                spatial_attention_timestep_skip_range=(int(shared.opts.pab_spacial_skip_start), int(shared.opts.pab_spacial_skip_end)),
                current_timestep_callback=lambda: shared.sd_model.current_timestep,
            )
            shared.sd_model.transformer.disable_cache()
            shared.sd_model.transformer.enable_cache(config)
            shared.log.debug(f'Transformer cache: type={config.__class__.__name__}')
            debug(f'Transformer cache: {vars(config)}')
        else:
            debug('Transformer cache: not enabled')
            shared.sd_model.transformer.disable_cache()
    except Exception as e:
        shared.log.error(f'Transformer cache: {e}')
        errors.display(e, 'Transformer cache')
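
For context, set_cache() is a thin wrapper over the caching hooks diffusers exposes on transformer models (enable_cache / disable_cache from its CacheMixin). Below is a minimal standalone sketch of that underlying API, assuming a recent diffusers release; the pipeline, checkpoint, prompt and skip values are illustrative, taken from the upstream PyramidAttentionBroadcast example rather than from any sdnext setting.

import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained('THUDM/CogVideoX-5b', torch_dtype=torch.bfloat16).to('cuda')
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,                     # reuse cached attention output for every 2nd block
    spatial_attention_timestep_skip_range=(100, 800),         # only skip attention inside this timestep window
    current_timestep_callback=lambda: pipe.current_timestep,  # tells the hook where in denoising we are
)
pipe.transformer.enable_cache(config)   # same call set_cache() makes on shared.sd_model.transformer
video = pipe(prompt='an astronaut riding a horse').frames[0]
pipe.transformer.disable_cache()        # restore the uncached forward pass

sdnext builds the equivalent config objects from shared.opts values and re-applies them whenever the cache options change, which is why set_cache() always calls disable_cache() before enable_cache().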