# copied from https://github.com/Birch-san/sdxl-play/blob/main/src/attn/natten_attn_processor.py
import os
from typing import Optional

import torch
from torch.nn import Linear
from einops import rearrange
from diffusers.models.attention import Attention

from installer import install, log


def init():
    try:
        os.environ['NATTEN_CUDA_ARCH'] = '8.0;8.6'  # restrict a from-source natten build to Ampere (SM 8.0/8.6)
        install('natten')
        import natten
        return natten
    except Exception as e:
        log.error(f'Init natten: {e}')
        return None


def fuse_qkv(attn: Attention) -> None:
    has_bias = attn.to_q.bias is not None
    qkv = Linear(in_features=attn.to_q.in_features, out_features=attn.to_q.out_features*3, bias=has_bias, dtype=attn.to_q.weight.dtype, device=attn.to_q.weight.device)
    # fold attn.scale into the q projection so the processor can skip the explicit 1/sqrt(d) scaling
    qkv.weight.data.copy_(torch.cat([attn.to_q.weight.data * attn.scale, attn.to_k.weight.data, attn.to_v.weight.data]))
    if has_bias:
        qkv.bias.data.copy_(torch.cat([attn.to_q.bias.data * attn.scale, attn.to_k.bias.data, attn.to_v.bias.data]))
    setattr(attn, 'qkv', qkv)  # noqa: B010
    del attn.to_q, attn.to_k, attn.to_v


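# Fused layout, illustrated: after fuse_qkv(attn), a single matmul yields q, k and v
# concatenated along the channel dimension, with attn.scale already baked into q:
#   qkv = attn.qkv(x)               # (..., 3 * inner_dim)
#   q, k, v = qkv.chunk(3, dim=-1)  # each (..., inner_dim); q is pre-scaled

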
def fuse_vae_qkv(vae) -> None:
    # in the SD VAE, self-attention only appears in the encoder/decoder mid blocks
    for attn in [*vae.encoder.mid_block.attentions, *vae.decoder.mid_block.attentions]:
        fuse_qkv(attn)


class NattenAttnProcessor:
    kernel_size: int

    def __init__(self, kernel_size: int):
        self.kernel_size = kernel_size

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ):
        import natten
        assert hasattr(attn, 'qkv'), "Did not find property qkv on attn. Expected you to fuse its q_proj, k_proj, v_proj weights and biases beforehand, and multiply attn.scale into the q weights and bias."
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        # assumes MHA (as opposed to GQA)
        inner_dim: int = attn.qkv.out_features // 3
        if attention_mask is not None:
            raise ValueError("No mask customization for neighbourhood attention; the mask is already complicated enough as it is")
        if encoder_hidden_states is not None:
            raise ValueError("NATTEN cannot be used for cross-attention. I think.")
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states)
        # NATTEN expects channels-last spatial layout
        hidden_states = rearrange(hidden_states, '... c h w -> ... h w c')
        qkv = attn.qkv(hidden_states)
        # assumes MHA (as opposed to GQA)
        q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=inner_dim)
        qk = natten.functional.na2d_qk(q, k, self.kernel_size, 1)  # natten2dqk; q was pre-scaled in fuse_qkv, so no extra 1/sqrt(d) here
        a = torch.softmax(qk, dim=-1)
        hidden_states = natten.functional.na2d_av(a, v, self.kernel_size, 1)  # natten2dav; weighted sum over each pixel's neighbourhood
        hidden_states = rearrange(hidden_states, "n nh h w e -> n h w (nh e)")
        linear_proj, dropout = attn.to_out
        hidden_states = linear_proj(hidden_states)
        hidden_states = dropout(hidden_states)
        hidden_states = rearrange(hidden_states, '... h w c -> ... c h w')
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        return hidden_states


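# Shape sketch (illustrative, following the rearranges above): with batch n, heads nh,
# spatial size h x w and head dim e,
#   natten.functional.na2d_qk(q, k, kernel, 1) -> (n, nh, h, w, kernel * kernel)
#   natten.functional.na2d_av(a, v, kernel, 1) -> (n, nh, h, w, e)
# i.e. each pixel attends only to the kernel x kernel neighbourhood around it.

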
def enable_natten(pipe):
    if not hasattr(pipe, 'vae'):
        return
    natten = init()
    kernel_size = 17  # neighbourhood size in latent pixels; NATTEN requires an odd kernel
    if natten is not None:
        log.info(f'VAE natten: version={natten.__version__} kernel={kernel_size}')
        fuse_vae_qkv(pipe.vae)
        pipe.vae.set_attn_processor(NattenAttnProcessor(kernel_size=kernel_size))
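

# Illustrative usage (assumes a standard diffusers pipeline whose VAE is an AutoencoderKL):
#   from diffusers import StableDiffusionXLPipeline
#   pipe = StableDiffusionXLPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0')
#   enable_natten(pipe)  # fuses the VAE mid-block qkv and installs NattenAttnProcessor
#   image = pipe('a photo of a cat').images[0]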