1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/token_merge.py
Vladimir Mandic 6760632f38 major model load refactor
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2024-10-26 13:22:29 -04:00

78 lines
3.1 KiB
Python

from modules import shared
def apply_token_merging(sd_model):
    """Bring the token-merging patch on ``sd_model`` in line with the current options.

    ``shared.opts.token_merging_method`` selects 'ToMe' (via ``tomesd``) or 'ToDo'
    (via ``modules.todo.todo_utils``). The selected method is active only when its
    ratio option (``tome_ratio`` / ``todo_ratio``) is > 0. The ratio each method was
    applied with is recorded on the model as ``applied_tome`` / ``applied_todo``,
    where 0 means "not applied".

    The patch of any method that is not active is removed first. The active method,
    if any, is applied afterwards. Errors raised while importing, applying, or
    removing a patch are logged as warnings rather than propagated.
    """
    method = shared.opts.token_merging_method
    use_tome = method == 'ToMe' and shared.opts.tome_ratio > 0
    use_todo = method == 'ToDo' and shared.opts.todo_ratio > 0
    # Undo stale patches instead of only zeroing their markers. A zeroed marker with
    # the patch still installed can never be cleaned up, because
    # remove_token_merging only acts on markers > 0.
    if not use_tome:
        _remove_tome(sd_model)
    if not use_todo:
        _remove_todo(sd_model)
    if use_tome:
        _apply_tome(sd_model)
    elif use_todo:
        _apply_todo(sd_model)


def _blocked_by_hypertile():
    """Return True and warn when HyperTile for UNet is enabled and ``cmd_opts.experimental`` is not set."""
    if shared.opts.hypertile_unet_enabled and not shared.cmd_opts.experimental:
        shared.log.warning('Token merging not supported with HyperTile for UNet')
        return True
    return False


def _remove_tome(sd_model):
    """Undo a recorded ToMe patch and set ``applied_tome`` to 0.

    If removal fails, the marker is left unchanged so it still signals that a
    patch may be installed.
    """
    if getattr(sd_model, 'applied_tome', 0) > 0:
        try:
            import tomesd
            tomesd.remove_patch(sd_model)
        except Exception as e:
            shared.log.warning(f'Token merging remove failed: method=ToMe pipeline={sd_model.__class__.__name__} error={e}')
            return
    sd_model.applied_tome = 0


def _remove_todo(sd_model):
    """Undo a recorded ToDo patch and set ``applied_todo`` to 0.

    If removal fails, the marker is left unchanged so it still signals that a
    patch may be installed.
    """
    if getattr(sd_model, 'applied_todo', 0) > 0:
        try:
            from modules.todo.todo_utils import remove_patch
            remove_patch(sd_model)
        except Exception as e:
            shared.log.warning(f'Token merging remove failed: method=ToDo pipeline={sd_model.__class__.__name__} error={e}')
            return
    sd_model.applied_todo = 0


def _apply_tome(sd_model):
    """Apply ToMe at ``shared.opts.tome_ratio`` unless it is already applied at that ratio."""
    ratio = shared.opts.tome_ratio
    if getattr(sd_model, 'applied_tome', 0) == ratio:
        return
    if _blocked_by_hypertile():
        return
    try:
        import installer
        # ask the project installer to provide tomesd before importing it
        installer.install('tomesd', 'tomesd', ignore=False)
        import tomesd
        tomesd.apply_patch(
            sd_model,
            ratio=ratio,
            use_rand=False, # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False,
        )
        shared.log.info(f'Applying ToMe: ratio={ratio}')
        sd_model.applied_tome = ratio
    except Exception:
        shared.log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__}')


def _apply_todo(sd_model):
    """Apply ToDo to ``sd_model.unet`` at ``shared.opts.todo_ratio`` unless it is already applied at that ratio."""
    ratio = shared.opts.todo_ratio
    if getattr(sd_model, 'applied_todo', 0) == ratio:
        return
    if _blocked_by_hypertile():
        return
    try:
        from modules.todo.todo_utils import patch_attention_proc
        token_merge_args = {
            "ratio": ratio,
            "merge_tokens": "keys/values",
            "merge_method": "downsample",
            "downsample_method": "nearest",
            "downsample_factor": 2,
            "timestep_threshold_switch": 0.0,
            "timestep_threshold_stop": 0.0,
            "downsample_factor_level_2": 1,
            "ratio_level_2": 0.0,
        }
        patch_attention_proc(sd_model.unet, token_merge_args=token_merge_args)
        shared.log.info(f'Applying ToDo: ratio={ratio}')
        sd_model.applied_todo = ratio
    except Exception:
        shared.log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__}')
def remove_token_merging(sd_model):
    """Undo every token-merging patch recorded on ``sd_model``.

    ToMe is removed when ``applied_tome`` > 0, via ``tomesd.remove_patch``. ToDo is
    removed when ``applied_todo`` > 0, via ``modules.todo.todo_utils.remove_patch``.
    Each method is handled independently, and a marker is reset to 0 only after its
    removal succeeds. A failure is logged as a warning instead of being silently
    ignored, and it leaves the marker set so the patch is still known to be present.
    """
    if getattr(sd_model, 'applied_tome', 0) > 0:
        try:
            import tomesd
            tomesd.remove_patch(sd_model)
            sd_model.applied_tome = 0
        except Exception as e:
            shared.log.warning(f'Token merging remove failed: method=ToMe pipeline={sd_model.__class__.__name__} error={e}')
    if getattr(sd_model, 'applied_todo', 0) > 0:
        try:
            from modules.todo.todo_utils import remove_patch
            remove_patch(sd_model)
            sd_model.applied_todo = 0
        except Exception as e:
            shared.log.warning(f'Token merging remove failed: method=ToDo pipeline={sd_model.__class__.__name__} error={e}')