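"""Configure diffusers transformer-level caching for the currently loaded model.

Supports FasterCache and PyramidAttentionBroadcast; defaults are read from shared.opts.
"""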
import os
import diffusers
from modules import shared, errors


# trace logging is opt-in: set the SD_PROCESS_DEBUG environment variable to enable it
debug = shared.log.trace if os.environ.get('SD_PROCESS_DEBUG', None) is not None else lambda *args, **kwargs: None


def set_cache(faster_cache=None, pyramid_attention_broadcast=None):
    if not shared.sd_loaded or not hasattr(shared.sd_model, 'transformer'):
        return # no model loaded or model has no transformer component
    # explicit arguments take precedence over user settings
    faster_cache = faster_cache if faster_cache is not None else shared.opts.faster_cache_enabled
    pyramid_attention_broadcast = pyramid_attention_broadcast if pyramid_attention_broadcast is not None else shared.opts.pab_enabled
    # continue even when both are disabled so a previously enabled cache is turned off in the else branch below
    if (not hasattr(shared.sd_model.transformer, 'enable_cache')) or (not hasattr(shared.sd_model.transformer, 'disable_cache')):
        shared.log.debug(f'Transformer cache: cls={shared.sd_model.transformer.__class__.__name__} fc={faster_cache} pab={pyramid_attention_broadcast} not supported')
        return
    try:
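        # FasterCache reuses attention states across nearby denoising steps and can skip
        # the unconditional (CFG) batch on some steps; all thresholds below come from shared.opts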
        if faster_cache: # https://github.com/huggingface/diffusers/pull/10163
            distilled = shared.opts.fc_guidance_distilled or shared.sd_model_type == 'f1' # FLUX models are treated as guidance-distilled
            config = diffusers.FasterCacheConfig(
                spatial_attention_block_skip_range=shared.opts.fc_spacial_skip_range,
                spatial_attention_timestep_skip_range=(int(shared.opts.fc_spacial_skip_start), int(shared.opts.fc_spacial_skip_end)),
                unconditional_batch_skip_range=shared.opts.fc_uncond_skip_range,
                unconditional_batch_timestep_skip_range=(int(shared.opts.fc_uncond_skip_start), int(shared.opts.fc_uncond_skip_end)),
                attention_weight_callback=lambda _: shared.opts.fc_attention_weight,
                tensor_format=shared.opts.fc_tensor_format, # TODO fc: autodetect tensor format based on model
                is_guidance_distilled=distilled, # TODO fc: autodetect distilled based on model
                current_timestep_callback=lambda: shared.sd_model.current_timestep,
            )
            shared.sd_model.transformer.disable_cache() # clear any previously active cache before enabling
            shared.sd_model.transformer.enable_cache(config)
            shared.log.debug(f'Transformer cache: type={config.__class__.__name__}')
            debug(f'Transformer cache: {vars(config)}')
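        # PyramidAttentionBroadcast reuses (broadcasts) attention outputs across nearby
        # timesteps instead of recomputing them, within the configured skip ranges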
        elif pyramid_attention_broadcast: # https://github.com/huggingface/diffusers/pull/9562
            config = diffusers.PyramidAttentionBroadcastConfig(
                spatial_attention_block_skip_range=shared.opts.pab_spacial_skip_range,
                spatial_attention_timestep_skip_range=(int(shared.opts.pab_spacial_skip_start), int(shared.opts.pab_spacial_skip_end)),
                current_timestep_callback=lambda: shared.sd_model.current_timestep,
            )
            shared.sd_model.transformer.disable_cache()
            shared.sd_model.transformer.enable_cache(config)
            shared.log.debug(f'Transformer cache: type={config.__class__.__name__}')
            debug(f'Transformer cache: {vars(config)}')
        else:
            debug('Transformer cache: not enabled')
            shared.sd_model.transformer.disable_cache() # neither cache requested: ensure any previous cache is off
    except Exception as e:
        shared.log.error(f'Transformer cache: {e}')
        errors.display(e, 'Transformer cache')
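
# Usage sketch (illustrative assumption, not part of the original module):
#   set_cache()                          # apply whatever shared.opts currently enables
#   set_cache(faster_cache=True)         # force FasterCache on for the loaded model
#   set_cache(faster_cache=False, pyramid_attention_broadcast=False) # disable caching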