from typing import Iterable, List, Generic, Optional, TypeVar

import functools
import itertools

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import Attention
import torch.nn as nn


__all__ = ['ObjectHooker', 'ModuleLocator', 'AggregateHooker', 'UNetCrossAttentionLocator']


ModuleType = TypeVar('ModuleType')
ModuleListType = TypeVar('ModuleListType', bound=List)


class ModuleLocator(Generic[ModuleType]):
    """Strategy interface for locating modules of interest inside a model."""

    def locate(self, model: nn.Module) -> List[ModuleType]:
        raise NotImplementedError


class ObjectHooker(Generic[ModuleType]):
    """Reversibly monkey-patches methods on a wrapped module; also usable as a context manager."""

    def __init__(self, module: ModuleType):
        self.module: ModuleType = module
        self.hooked = False
        self.old_state = {}

    def __enter__(self):
        self.hook()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.unhook()

    def hook(self):
        if self.hooked:
            raise RuntimeError('Module is already hooked')

        self.old_state = {}
        self.hooked = True
        self._hook_impl()

        return self

    def unhook(self):
        if not self.hooked:
            raise RuntimeError('Module is not hooked')

        # restore every method saved by monkey_patch; keys are 'old_fn_<name>'
        for k, v in self.old_state.items():
            if k.startswith('old_fn_'):
                setattr(self.module, k[len('old_fn_'):], v)

        self.hooked = False
        self._unhook_impl()

        return self

    def monkey_patch(self, fn_name, fn, strict: bool = True):
        # save the original attribute, then replace it with `fn` pre-bound to the module
        try:
            self.old_state[f'old_fn_{fn_name}'] = getattr(self.module, fn_name)
            setattr(self.module, fn_name, functools.partial(fn, self.module))
        except AttributeError:
            if strict:
                raise

    def monkey_super(self, fn_name, *args, **kwargs):
        # invoke the original (pre-patch) implementation
        return self.old_state[f'old_fn_{fn_name}'](*args, **kwargs)

    def _hook_impl(self):
        raise NotImplementedError

    def _unhook_impl(self):
        pass
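

# Usage sketch (illustrative, not part of the original file): a hypothetical
# hooker that counts forward passes on an arbitrary nn.Module. Note that
# monkey_patch binds the module as the first argument of the replacement
# function, so `_forward` below receives the patched module, not the hooker.
#
#     class CountingHooker(ObjectHooker[nn.Module]):
#         def _hook_impl(self):
#             self.call_count = 0
#
#             def _forward(module, *args, **kwargs):
#                 self.call_count += 1
#                 return self.monkey_super('forward', *args, **kwargs)
#
#             self.monkey_patch('forward', _forward)
#
#     with CountingHooker(layer) as h:
#         layer(x)  # runs the patched forward, increments h.call_count
#     layer(x)      # original forward is restored on exit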


class AggregateHooker(ObjectHooker[ModuleListType]):
    """Treats a list of hookers as a single hooker, hooking and unhooking them together."""

    def _hook_impl(self):
        for h in self.module:
            h.hook()

    def _unhook_impl(self):
        for h in self.module:
            h.unhook()

    def register_hook(self, hook: ObjectHooker):
        self.module.append(hook)
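

# Usage sketch (illustrative): aggregate several hookers and manage them as one.
# `CountingHooker`, `layer_a`, and `layer_b` are assumed names from the sketch above.
#
#     agg = AggregateHooker([])  # the wrapped "module" is a plain list of hookers
#     for layer in [layer_a, layer_b]:
#         agg.register_hook(CountingHooker(layer))
#     with agg:
#         ...  # all registered hookers are active here; all are unhooked on exit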


class UNetCrossAttentionLocator(ModuleLocator[Attention]):
    def __init__(self, restrict: Optional[Iterable[int]] = None, locate_middle_block: bool = False):
        # `restrict`, if given, is a collection of attention-module indices to keep
        self.restrict = restrict
        self.layer_names = []
        self.locate_middle_block = locate_middle_block

    def locate(self, model: UNet2DConditionModel) -> List[Attention]:
        """
        Locate all cross-attention modules in a UNet2DConditionModel.

        Args:
            model (`UNet2DConditionModel`): The model in which to locate the cross-attention modules.

        Returns:
            `List[Attention]`: The list of cross-attention modules.
        """
        self.layer_names.clear()
        blocks_list = []
        up_names = ['up'] * len(model.up_blocks)
        down_names = ['down'] * len(model.down_blocks)

        for unet_block, name in itertools.chain(
            zip(model.up_blocks, up_names),
            zip(model.down_blocks, down_names),
            zip([model.mid_block], ['mid']) if self.locate_middle_block else [],
        ):
            if 'CrossAttn' in unet_block.__class__.__name__:
                blocks = []

                for spatial_transformer in unet_block.attentions:
                    for transformer_block in spatial_transformer.transformer_blocks:
                        blocks.append(transformer_block.attn2)  # attn2 is the cross-attention layer

                # filter blocks and names in one pass so indices stay aligned when `restrict` is set
                kept = [(idx, b) for idx, b in enumerate(blocks) if self.restrict is None or idx in self.restrict]
                blocks_list.extend(b for _, b in kept)
                self.layer_names.extend(f'{name}-attn-{idx}' for idx, _ in kept)

        return blocks_list
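

# End-to-end usage sketch (illustrative; `unet` and `CountingHooker` are
# assumptions, not part of this file): locate every cross-attention layer in a
# diffusers UNet, wrap each in a hooker, and drive them through one
# AggregateHooker.
#
#     locator = UNetCrossAttentionLocator(restrict=None, locate_middle_block=False)
#     hookers = [CountingHooker(attn) for attn in locator.locate(unet)]
#     with AggregateHooker(hookers):
#         ...  # every attn2 module is patched inside this block
#     print(locator.layer_names)  # e.g. ['up-attn-0', 'up-attn-1', ...]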