1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/modules/todo/__init__.py
2024-04-28 20:29:28 -04:00

43 lines
1.4 KiB
Python

from modules.todo.todo_utils import patch_attention_proc
def apply_todo(model, p, method='todo'):
mp = p.height * p.width / 1024 / 1024
if mp < 1.0: # 512px
downsample_factor = 2
ratio = 0.38
downsample_factor_level_2 = 1
ratio_level_2 = 0.0
elif mp < 1.1: # 1024+
downsample_factor = 2
ratio = 0.75
downsample_factor_level_2 = 1
ratio_level_2 = 0.0
elif mp < 2.3:
downsample_factor = 3
ratio = 0.89
downsample_factor_level_2 = 1
ratio_level_2 = 0.0
elif mp < 8:
downsample_factor = 4
ratio = 0.9375
downsample_factor_level_2 = 1
ratio_level_2 = 0.0
else:
return
merge_method = "downsample" if method == "todo" else "similarity"
merge_tokens = "keys/values" if method == "todo" else "all"
token_merge_args = {
"ratio": ratio,
"merge_tokens": merge_tokens,
"merge_method": merge_method,
"downsample_method": "nearest",
"downsample_factor": downsample_factor,
"timestep_threshold_switch": 0.0,
"timestep_threshold_stop": 0.0,
"downsample_factor_level_2": downsample_factor_level_2,
"ratio_level_2": ratio_level_2
}
patch_attention_proc(model.unet, token_merge_args=token_merge_args)