# Mirror of https://github.com/vladmandic/sdnext.git (synced 2026-01-27 15:02:48 +03:00)
# 78 lines, 3.1 KiB, Python
from modules import shared
|
|
|
|
|
|
def apply_token_merging(sd_model):
    """Patch *sd_model* with the token-merging method selected in settings.

    ``shared.opts.token_merging_method`` selects 'ToMe' (the ``tomesd`` package)
    or 'ToDo' (``modules.todo.todo_utils``). ``shared.opts.tome_ratio`` and
    ``shared.opts.todo_ratio`` give the strength, and a ratio of 0 disables the
    method. The applied ratio is recorded on the model as
    ``sd_model.applied_tome`` / ``sd_model.applied_todo`` (0 = not applied).
    This makes repeated calls with unchanged settings no-ops and tells
    ``remove_token_merging`` what to undo.

    A method that is no longer selected but is still recorded as applied is
    removed from the model before its flag is cleared, so a patch can never
    outlive its flag. Failures are logged as warnings; nothing is raised.

    Args:
        sd_model: the loaded model/pipeline to patch. The ToDo path patches
            ``sd_model.unet``.
    """
    # The two sections are independent: an early exit in one (unchanged ratio,
    # HyperTile conflict) must not skip clean-up of the other.
    _sync_tome(sd_model)
    _sync_todo(sd_model)


def _hypertile_blocks_token_merging():
    """Return True, after logging a warning, when HyperTile for UNet forbids token merging."""
    if shared.opts.hypertile_unet_enabled and not shared.cmd_opts.experimental:
        shared.log.warning('Token merging not supported with HyperTile for UNet')
        return True
    return False


def _sync_tome(sd_model):
    """Apply, keep, or remove the tomesd (ToMe) patch so it matches current settings."""
    current_tome = getattr(sd_model, 'applied_tome', 0)
    if not (shared.opts.token_merging_method == 'ToMe' and shared.opts.tome_ratio > 0):
        if current_tome > 0:
            # The patch is still on the model. Clearing only the flag would hide
            # it from remove_token_merging(), so undo the patch first.
            try:
                import tomesd
                tomesd.remove_patch(sd_model)
            except Exception as e:
                shared.log.warning(f'Token merging remove failed: method=ToMe {e}')
        sd_model.applied_tome = 0
        return
    if current_tome == shared.opts.tome_ratio:
        return  # already patched with the requested ratio
    if _hypertile_blocks_token_merging():
        return
    try:
        import installer
        installer.install('tomesd', 'tomesd', ignore=False)
        import tomesd
        tomesd.apply_patch(
            sd_model,
            ratio=shared.opts.tome_ratio,
            use_rand=False,  # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False,
        )
        shared.log.info(f'Applying ToMe: ratio={shared.opts.tome_ratio}')
        sd_model.applied_tome = shared.opts.tome_ratio
    except Exception as e:
        shared.log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__} {e}')


def _sync_todo(sd_model):
    """Apply, keep, or remove the ToDo attention-processor patch so it matches current settings."""
    current_todo = getattr(sd_model, 'applied_todo', 0)
    if not (shared.opts.token_merging_method == 'ToDo' and shared.opts.todo_ratio > 0):
        if current_todo > 0:
            # Same reasoning as for ToMe: never clear the flag of a live patch.
            try:
                from modules.todo.todo_utils import remove_patch
                remove_patch(sd_model)
            except Exception as e:
                shared.log.warning(f'Token merging remove failed: method=ToDo {e}')
        sd_model.applied_todo = 0
        return
    if current_todo == shared.opts.todo_ratio:
        return  # already patched with the requested ratio
    if _hypertile_blocks_token_merging():
        return
    try:
        from modules.todo.todo_utils import patch_attention_proc
        # Fixed ToDo settings; only the ratio comes from user options.
        token_merge_args = {
            "ratio": shared.opts.todo_ratio,
            "merge_tokens": "keys/values",
            "merge_method": "downsample",
            "downsample_method": "nearest",
            "downsample_factor": 2,
            "timestep_threshold_switch": 0.0,
            "timestep_threshold_stop": 0.0,
            "downsample_factor_level_2": 1,
            "ratio_level_2": 0.0,
        }
        patch_attention_proc(sd_model.unet, token_merge_args=token_merge_args)
        shared.log.info(f'Applying ToDo: ratio={shared.opts.todo_ratio}')
        sd_model.applied_todo = shared.opts.todo_ratio
    except Exception as e:
        shared.log.warning(f'Token merging not supported: pipeline={sd_model.__class__.__name__} {e}')
|
|
|
|
|
|
def remove_token_merging(sd_model):
    """Undo the token-merging patches recorded on *sd_model*.

    Reads ``sd_model.applied_tome`` and ``sd_model.applied_todo`` (missing
    attributes count as 0). For each flag that is > 0, the matching patch is
    removed and the flag is reset to 0. A method whose flag is 0 is left
    untouched, and no attribute is created.

    Removal is best-effort. On failure, a warning with the cause is logged, the
    flag is left as-is (the patch may still be present), and nothing is raised.
    """
    current_tome = getattr(sd_model, 'applied_tome', 0)
    current_todo = getattr(sd_model, 'applied_todo', 0)
    if current_tome > 0:
        try:
            import tomesd
            tomesd.remove_patch(sd_model)
            sd_model.applied_tome = 0
        except Exception as e:
            shared.log.warning(f'Token merging remove failed: method=ToMe {e}')
    if current_todo > 0:
        try:
            from modules.todo.todo_utils import remove_patch
            remove_patch(sd_model)
            sd_model.applied_todo = 0
        except Exception as e:
            shared.log.warning(f'Token merging remove failed: method=ToDo {e}')
|