from typing import Optional
from functools import wraps
import torch
from modules import rocm
from modules.errors import log
from installer import install, installed


def set_dynamic_attention():
    try:
        sdpa_pre_dynamic_atten = torch.nn.functional.scaled_dot_product_attention
        from modules.sd_hijack_dynamic_atten import dynamic_scaled_dot_product_attention
        torch.nn.functional.scaled_dot_product_attention = dynamic_scaled_dot_product_attention
        return sdpa_pre_dynamic_atten
    except Exception as err:
        log.error(f'Torch attention: type="dynamic attention" {err}')
        return None

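# Illustrative usage sketch, not called anywhere in this module: the return value is the
# original SDPA function, so the hijack can be reverted later.
#   original_sdpa = set_dynamic_attention()
#   ...
#   if original_sdpa is not None:
#       torch.nn.functional.scaled_dot_product_attention = original_sdpa

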
def set_triton_flash_attention(backend: str):
    try:
        if backend in {"rocm", "zluda"}: # flash_attn_triton_amd only works with AMD
            from modules.flash_attn_triton_amd import interface_fa
            sdpa_pre_triton_flash_atten = torch.nn.functional.scaled_dot_product_attention

            @wraps(sdpa_pre_triton_flash_atten)
            def sdpa_triton_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
                if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
                    if scale is None:
                        scale = query.shape[-1] ** (-0.5)
                    head_size_og = query.size(3)
                    if head_size_og % 8 != 0: # pad head dim up to a multiple of 8 before calling the kernel
                        query = torch.nn.functional.pad(query, [0, 8 - head_size_og % 8])
                        key = torch.nn.functional.pad(key, [0, 8 - head_size_og % 8])
                        value = torch.nn.functional.pad(value, [0, 8 - head_size_og % 8])
                    query = query.transpose(1, 2)
                    key = key.transpose(1, 2)
                    value = value.transpose(1, 2)
                    out_padded = torch.zeros_like(query)
                    interface_fa.fwd(query, key, value, out_padded, dropout_p, scale, is_causal)
                    return out_padded[..., :head_size_og].transpose(1, 2)
                else: # fall back to the original SDPA for float32, masked or large-head-dim calls
                    if enable_gqa:
                        kwargs["enable_gqa"] = enable_gqa
                    return sdpa_pre_triton_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)

            torch.nn.functional.scaled_dot_product_attention = sdpa_triton_flash_atten
            log.debug('Torch attention: type="Triton Flash attention"')
    except Exception as err:
        log.error(f'Torch attention: type="Triton Flash attention" {err}')

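# Illustrative sketch (not executed here): the Triton flash path only activates on AMD backends;
# any other backend leaves torch SDPA untouched.
#   set_triton_flash_attention("rocm")   # hijacks scaled_dot_product_attention
#   set_triton_flash_attention("cuda")   # no-op, backend not in {"rocm", "zluda"}

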
def set_flex_attention():
    try:
        from torch.nn.attention.flex_attention import flex_attention, create_block_mask

        def flex_attention_causal_mask(b, h, q_idx, kv_idx): # pylint: disable=unused-argument
            return q_idx >= kv_idx

        sdpa_pre_flex_atten = torch.nn.functional.scaled_dot_product_attention

        @wraps(sdpa_pre_flex_atten)
        def sdpa_flex_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor: # pylint: disable=unused-argument
            score_mod = None
            block_mask = None
            if attn_mask is not None:
                batch_size, num_heads = query.shape[:2]
                seq_len_q = query.shape[-2]
                seq_len_kv = key.shape[-2]
                if attn_mask.ndim == 2: # assume a 2d mask is (batch, seq_kv) and broadcast it over heads and query positions
                    attn_mask = attn_mask.view(attn_mask.shape[0], 1, 1, attn_mask.shape[1])
                attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
                if attn_mask.dtype == torch.bool: # boolean masks become block masks
                    def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                        return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
                    block_mask = create_block_mask(mask_mod, batch_size, None, seq_len_q, seq_len_kv, device=query.device)
                else: # float masks are applied additively via score_mod
                    def score_mod_fn(score, batch_idx, head_idx, q_idx, kv_idx):
                        return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
                    score_mod = score_mod_fn
            elif is_causal:
                block_mask = create_block_mask(flex_attention_causal_mask, query.shape[0], query.shape[1], query.shape[-2], key.shape[-2], device=query.device)
            return flex_attention(query, key, value, score_mod=score_mod, block_mask=block_mask, scale=scale, enable_gqa=enable_gqa)

        torch.nn.functional.scaled_dot_product_attention = sdpa_flex_atten
        log.debug('Torch attention: type="Flex attention"')
    except Exception as err:
        log.error(f'Torch attention: type="Flex attention" {err}')

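# Illustrative sketch (assumes a PyTorch build that ships torch.nn.attention.flex_attention,
# i.e. a recent release); shapes are arbitrary example values:
#   set_flex_attention()
#   q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
#   out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)  # routed through flex_attention

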
def set_ck_flash_attention(backend: str, device: torch.device):
    try:
        if backend == "rocm":
            if not installed('flash-attn'):
                log.info('Torch attention: type="Flash attention" building...')
                agent = rocm.Agent(device)
                install(rocm.get_flash_attention_command(agent), reinstall=True)
        else:
            install('flash-attn')
        from flash_attn import flash_attn_func
        sdpa_pre_flash_atten = torch.nn.functional.scaled_dot_product_attention

        @wraps(sdpa_pre_flash_atten)
        def sdpa_flash_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
            if query.shape[-1] <= 128 and attn_mask is None and query.dtype != torch.float32:
                is_unsqueezed = False
                if query.dim() == 3:
                    query = query.unsqueeze(0)
                    is_unsqueezed = True
                if key.dim() == 3:
                    key = key.unsqueeze(0)
                if value.dim() == 3:
                    value = value.unsqueeze(0)
                if enable_gqa:
                    key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
                    value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
                query = query.transpose(1, 2)
                key = key.transpose(1, 2)
                value = value.transpose(1, 2)
                attn_output = flash_attn_func(q=query, k=key, v=value, dropout_p=dropout_p, causal=is_causal, softmax_scale=scale).transpose(1, 2)
                if is_unsqueezed:
                    attn_output = attn_output.squeeze(0)
                return attn_output
            else:
                if enable_gqa:
                    kwargs["enable_gqa"] = enable_gqa
                return sdpa_pre_flash_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)

        torch.nn.functional.scaled_dot_product_attention = sdpa_flash_atten
        log.debug('Torch attention: type="Flash attention"')
    except Exception as err:
        log.error(f'Torch attention: type="Flash attention" {err}')

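# Illustrative sketch (not executed here): on ROCm the flash-attn package is built for the
# detected GPU agent, so a device is required; other backends install the generic wheel.
#   set_ck_flash_attention("rocm", torch.device("cuda:0"))

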
def set_sage_attention(backend: str, device: torch.device):
    try:
        install('sageattention')

        use_cuda_backend = False
        if (backend == "cuda") and (torch.cuda.get_device_capability(device) == (8, 6)):
            use_cuda_backend = True # Detect GPU architecture - sm86 confirmed to need CUDA backend workaround as Sage Attention + Triton causes NaNs
            try:
                from sageattention import sageattn_qk_int8_pv_fp16_cuda
            except Exception:
                use_cuda_backend = False

        if use_cuda_backend:
            from sageattention import sageattn_qk_int8_pv_fp16_cuda

            def sage_attn_impl(query, key, value, is_causal, scale):
                return sageattn_qk_int8_pv_fp16_cuda(
                    q=query, k=key, v=value,
                    tensor_layout="HND",
                    is_causal=is_causal,
                    sm_scale=scale,
                    return_lse=False,
                    pv_accum_dtype="fp32",
                )
        else:
            from sageattention import sageattn

            def sage_attn_impl(query, key, value, is_causal, scale):
                return sageattn(
                    q=query, k=key, v=value,
                    attn_mask=None,
                    dropout_p=0.0,
                    is_causal=is_causal,
                    scale=scale,
                )

        sdpa_pre_sage_atten = torch.nn.functional.scaled_dot_product_attention

        @wraps(sdpa_pre_sage_atten)
        def sdpa_sage_atten(query: torch.FloatTensor, key: torch.FloatTensor, value: torch.FloatTensor, attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, enable_gqa: bool = False, **kwargs) -> torch.FloatTensor:
            if (query.shape[-1] in {128, 96, 64}) and (attn_mask is None) and (query.dtype != torch.float32):
                if enable_gqa:
                    key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
                    value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)
                # Call pre-selected sage attention implementation
                return sage_attn_impl(query, key, value, is_causal, scale)
            else:
                if enable_gqa:
                    kwargs["enable_gqa"] = enable_gqa
                return sdpa_pre_sage_atten(query=query, key=key, value=value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale, **kwargs)

        torch.nn.functional.scaled_dot_product_attention = sdpa_sage_atten
        log.debug(f'Torch attention: type="Sage attention" backend={"cuda" if use_cuda_backend else "auto"}')
    except Exception as err:
        log.error(f'Torch attention: type="Sage attention" {err}')

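# Illustrative sketch (not executed here): the Sage path only handles head dims of 64/96/128,
# no mask, and non-float32 inputs; everything else falls back to the saved SDPA.
#   set_sage_attention("cuda", torch.device("cuda:0"))

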
def set_diffusers_attention(pipe, quiet:bool=False):
    from modules import shared
    import diffusers.models.attention_processor as p

    def set_attn(pipe, attention, name:str=None):
        if attention is None:
            return
        # other models use their own attention processor
        if getattr(pipe, "unet", None) is not None and hasattr(pipe.unet, "set_attn_processor"):
            try:
                pipe.unet.set_attn_processor(attention)
            except Exception as e:
                if 'Nunchaku' in pipe.unet.__class__.__name__:
                    pass
                else:
                    shared.log.error(f'Torch attention: type="{name}" cls={attention.__class__.__name__} pipe={pipe.__class__.__name__} {e}')
        """ # each transformer typically has its own attention processor
        if getattr(pipe, "transformer", None) is not None and hasattr(pipe.transformer, "set_attn_processor"):
            try:
                pipe.transformer.set_attn_processor(attention)
            except Exception as e:
                if 'Nunchaku' in pipe.transformer.__class__.__name__:
                    pass
                else:
                    shared.log.error(f'Torch attention: type="{name}" cls={attention.__class__.__name__} pipe={pipe.__class__.__name__} {e}')
        """

    shared.log.quiet(quiet, f'Setting model: attention="{shared.opts.cross_attention_optimization}"')
    if shared.opts.cross_attention_optimization == "Disabled":
        pass # do nothing
    elif shared.opts.cross_attention_optimization == "Scaled-Dot-Product": # The default set by Diffusers
        # set_attn(pipe, p.AttnProcessor2_0(), name="Scaled-Dot-Product")
        pass
    elif shared.opts.cross_attention_optimization == "xFormers":
        if hasattr(pipe, 'enable_xformers_memory_efficient_attention'):
            pipe.enable_xformers_memory_efficient_attention()
        else:
            shared.log.warning(f"Attention: xFormers is not compatible with {pipe.__class__.__name__}")
    elif shared.opts.cross_attention_optimization == "Batch matrix-matrix":
        set_attn(pipe, p.AttnProcessor(), name="Batch matrix-matrix")
    elif shared.opts.cross_attention_optimization == "Dynamic Attention BMM":
        from modules.sd_hijack_dynamic_atten import DynamicAttnProcessorBMM
        set_attn(pipe, DynamicAttnProcessorBMM(), name="Dynamic Attention BMM")

    if shared.opts.attention_slicing != "Default" and hasattr(pipe, "enable_attention_slicing") and hasattr(pipe, "disable_attention_slicing"):
        if shared.opts.attention_slicing:
            pipe.enable_attention_slicing()
        else:
            pipe.disable_attention_slicing()
        shared.log.debug(f"Torch attention: slicing={shared.opts.attention_slicing}")

    pipe.current_attn_name = shared.opts.cross_attention_optimization
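# Illustrative sketch (hypothetical pipeline, not executed here): applies the attention processor
# and slicing options from shared.opts to an already loaded diffusers pipeline.
#   from diffusers import StableDiffusionPipeline
#   pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
#   set_diffusers_attention(pipe)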