1
0
mirror of https://github.com/vladmandic/sdnext.git synced 2026-01-27 15:02:48 +03:00
Files
sdnext/scripts/daam/hook.py
Vladimir Mandic ae25cb8880 linting
Signed-off-by: Vladimir Mandic <mandic00@live.com>
2025-09-25 14:33:21 -04:00

128 lines
3.8 KiB
Python

import functools
import itertools
from typing import Collection, Generic, List, Optional, TypeVar

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention
import torch.nn as nn
# Public names exported by ``from <this module> import *``.
__all__ = ['ObjectHooker', 'ModuleLocator', 'AggregateHooker', 'UNetCrossAttentionLocator']
# Type variable for the object a locator returns / a hooker patches.
ModuleType = TypeVar('ModuleType')
# Type variable for the list of child hookers managed by AggregateHooker.
ModuleListType = TypeVar('ModuleListType', bound=List)


class ModuleLocator(Generic[ModuleType]):
    """Strategy object that finds sub-modules of interest inside a model."""

    def locate(self, model: nn.Module) -> List[ModuleType]:
        """Return the sub-modules of ``model`` that this locator targets.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError


class ObjectHooker(Generic[ModuleType]):
    """Temporarily replace attributes (usually methods) on a single object.

    Subclasses implement ``_hook_impl`` -- typically by calling ``monkey_patch`` --
    and may implement ``_unhook_impl`` for extra cleanup. Patches are applied by
    ``hook()`` (or entering the context manager) and reverted by ``unhook()``
    (or leaving it).
    """

    # Prefix of the ``old_state`` keys that hold attributes replaced by ``monkey_patch``.
    _PATCH_PREFIX = 'old_fn_'

    def __init__(self, module: ModuleType):
        self.module: ModuleType = module
        self.hooked = False
        # 'old_fn_<name>' -> the attribute value that monkey_patch replaced
        self.old_state = {}
        # Names whose original value was NOT an instance attribute (e.g. a method found on
        # the class). Restoring these deletes the instance-level patch instead of writing a
        # bound method back into the instance __dict__, which would shadow the class
        # attribute forever and create a module -> bound method -> module reference cycle.
        self._delete_on_restore = set()

    def __enter__(self):
        self.hook()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.unhook()

    def hook(self):
        """Apply the patches defined by ``_hook_impl``.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            RuntimeError: if this hooker is already hooked.
            Any exception raised by ``_hook_impl``. In that case, patches applied before
            the failure are reverted and the hooker is left unhooked.
        """
        if self.hooked:
            raise RuntimeError('Already hooked module')
        self.old_state = {}
        self._delete_on_restore = set()
        self.hooked = True
        try:
            self._hook_impl()
        except BaseException:
            # A failure here would otherwise leave the object half-patched, and because
            # __enter__ raised, __exit__ would never run to undo it.
            self._restore_patches()
            self.hooked = False
            raise
        return self

    def unhook(self):
        """Revert every patch made by ``monkey_patch`` and then run ``_unhook_impl``.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            RuntimeError: if this hooker is not currently hooked.
        """
        if not self.hooked:
            raise RuntimeError('Module is not hooked')
        self._restore_patches()
        self.hooked = False
        self._unhook_impl()
        return self

    def monkey_patch(self, fn_name, fn, strict: bool = True):
        """Replace ``self.module.<fn_name>`` with ``functools.partial(fn, self.module)``.

        ``fn`` therefore receives the patched object as its first positional argument.
        The replaced value is saved so that ``monkey_super`` can call it and ``unhook``
        can restore it.

        Args:
            fn_name: name of the attribute to replace.
            fn: replacement callable.
            strict: when False, an ``AttributeError`` raised while reading or setting
                the attribute is ignored and nothing is patched.
        """
        try:
            original = getattr(self.module, fn_name)
            instance_dict = getattr(self.module, '__dict__', None)
            delete_on_restore = instance_dict is not None and fn_name not in instance_dict
            setattr(self.module, fn_name, functools.partial(fn, self.module))
        except AttributeError:
            if strict:
                raise
            return
        # Record only after setattr succeeded, so unhook never "restores" an attribute
        # that was never actually replaced.
        self.old_state[f'{self._PATCH_PREFIX}{fn_name}'] = original
        if delete_on_restore:
            self._delete_on_restore.add(fn_name)

    def monkey_super(self, fn_name, *args, **kwargs):
        """Call the original value of ``fn_name`` that ``monkey_patch`` replaced."""
        return self.old_state[f'{self._PATCH_PREFIX}{fn_name}'](*args, **kwargs)

    def _restore_patches(self):
        """Undo every patch recorded in ``old_state``, most recent first."""
        prefix = self._PATCH_PREFIX
        for key, original in reversed(list(self.old_state.items())):
            if not key.startswith(prefix):
                continue
            fn_name = key[len(prefix):]
            if fn_name in self._delete_on_restore:
                # Drop the instance-level patch so lookup falls back to the class attribute.
                if fn_name in vars(self.module):
                    delattr(self.module, fn_name)
            else:
                setattr(self.module, fn_name, original)

    def _hook_impl(self):
        """Install this hooker's patches. Subclasses must override this method."""
        raise NotImplementedError

    def _unhook_impl(self):
        """Extra cleanup run by ``unhook`` after the patches are reverted; no-op by default."""
        pass
class AggregateHooker(ObjectHooker[ModuleListType]):
    """Hooker that treats a list of child hookers as one unit.

    ``self.module`` is the list of child ``ObjectHooker`` instances. Hooking this
    object hooks every child in order. Unhooking it unhooks every child in reverse order.
    """

    def _hook_impl(self):
        hooked_children = []
        try:
            for child in self.module:
                child.hook()
                hooked_children.append(child)
        except BaseException:
            # Roll back the children that were already hooked so none stay patched.
            for child in reversed(hooked_children):
                child.unhook()
            raise

    def _unhook_impl(self):
        # LIFO order: when two children patched the same attribute, the later child must
        # restore first, so the value saved by the earlier child is the one left in place.
        for child in reversed(self.module):
            child.unhook()

    def register_hook(self, hook: ObjectHooker):
        """Append ``hook`` to the managed children. It is not hooked here."""
        self.module.append(hook)
class UNetCrossAttentionLocator(ModuleLocator[Attention]):
    """Locate the cross-attention (``attn2``) modules of a diffusers UNet.

    Args:
        restrict: optional collection of transformer indices to keep; ``None`` keeps all.
            Indices are counted separately inside each UNet block.
        locate_middle_block: also search ``model.mid_block``.

    After ``locate`` runs, ``layer_names[i]`` names the i-th returned module as
    ``'<up|down|mid>-attn-<index>'``, where ``<index>`` is the module's position
    inside its UNet block.
    """

    def __init__(self, restrict: Optional[Collection[int]] = None, locate_middle_block: bool = False):
        self.restrict = restrict
        self.layer_names = []
        self.locate_middle_block = locate_middle_block

    def locate(self, model: UNet2DConditionModel) -> List[Attention]:
        """
        Locate all cross-attention modules in a UNet2DConditionModel.
        Up blocks are searched first, then down blocks, then (optionally) the mid block.
        Only blocks whose class name contains ``'CrossAttn'`` are searched.
        Args:
            model (`UNet2DConditionModel`): The model to locate the cross-attention modules in.
        Returns:
            `List[Attention]`: The list of cross-attention modules. ``self.layer_names`` is
            rebuilt in parallel and has the same length.
        """
        self.layer_names.clear()
        blocks_list = []
        named_blocks = itertools.chain(
            zip(model.up_blocks, itertools.repeat('up')),
            zip(model.down_blocks, itertools.repeat('down')),
            zip([model.mid_block], ['mid']) if self.locate_middle_block else [],
        )
        for unet_block, name in named_blocks:
            if 'CrossAttn' not in unet_block.__class__.__name__:
                continue
            cross_attns = [
                transformer_block.attn2
                for spatial_transformer in unet_block.attentions
                for transformer_block in spatial_transformer.transformer_blocks
            ]
            # Filter modules and names together, using the index before filtering. The
            # previous code filtered the modules first and then built names from the
            # filtered length, which misaligned names and modules whenever restrict was set.
            for idx, attn in enumerate(cross_attns):
                if self.restrict is None or idx in self.restrict:
                    blocks_list.append(attn)
                    self.layer_names.append(f'{name}-attn-{idx}')
        return blocks_list