# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Not a contribution
# Changes made by NVIDIA CORPORATION & AFFILIATES enabling ConsiStory or otherwise documented as NVIDIA-proprietary
# are not a contribution and subject to the license under the LICENSE file located at the root directory.
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from diffusers.utils import USE_PEFT_BACKEND
from diffusers.models.attention_processor import Attention
from .consistory_utils import AnchorCache, FeatureInjector, QueryStore
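
# ConsiStory attention processors. ConsiStory (NVIDIA) keeps a subject consistent across
# a batch of generations without any training by (a) recording attention maps to localize
# the subject and (b) extending self-attention keys/values with subject tokens shared
# across the batch. ConsistoryAttnStoreProcessor handles (a); the extended processor
# below handles (b).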
class ConsistoryAttnStoreProcessor:
def __init__(self, attnstore, place_in_unet):
super().__init__()
self.attnstore = attnstore
self.place_in_unet = place_in_unet
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, record_attention=True, **kwargs):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
query = attn.to_q(hidden_states)
is_cross = encoder_hidden_states is not None
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
# only need to store attention maps during the Attend and Excite process
# if attention_probs.requires_grad:
if record_attention:
self.attnstore(attention_probs, is_cross, self.place_in_unet, attn.heads)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
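
# Both processors are installed by register_extended_self_attn() at the end of this file:
# cross-attention layers get ConsistoryAttnStoreProcessor (to record attention maps),
# while self-attention layers get the extended processor below.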
class ConsistoryExtendedAttnXFormersAttnProcessor:
r"""
Processor for implementing memory efficient attention using xFormers.
Args:
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
operator.
"""
def __init__(self, place_in_unet, attnstore, extended_attn_kwargs, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
self.t_range = extended_attn_kwargs.get('t_range', [])
self.extend_kv_unet_parts = extended_attn_kwargs.get('extend_kv_unet_parts', ['down', 'mid', 'up'])
self.place_in_unet = place_in_unet
self.curr_unet_part = self.place_in_unet.split('_')[0]
self.attnstore = attnstore
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
perform_extend_attn: bool = False,
query_store: Optional[QueryStore] = None,
feature_injector: Optional[FeatureInjector] = None,
anchors_cache: Optional[AnchorCache] = None,
**kwargs
) -> torch.FloatTensor:
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
else:
batch_size, wh, channel = hidden_states.shape
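            # 3D input: tokens are a flattened spatial map; recover height/width assuming
            # a square latent.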
height = width = int(wh ** 0.5)
is_cross = encoder_hidden_states is not None
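        # Extended attention only applies to self-attention layers, within the configured
        # timestep ranges, and within the configured UNet parts (down/mid/up).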
        perform_extend_attn = (
            perform_extend_attn
            and not is_cross
            and any(x[0] <= self.attnstore.curr_iter <= x[1] for x in self.t_range)
            and self.curr_unet_part in self.extend_kv_unet_parts
        )
batch_size, key_tokens, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
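        # Query caching/injection: in 'cache' mode the vanilla query features are stored;
        # in 'inject' mode they are restored, so extended attention runs with the
        # original (non-extended) queries.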
if (self.curr_unet_part in self.extend_kv_unet_parts) and query_store and query_store.mode == 'cache':
query_store.cache_query(query, self.place_in_unet)
elif perform_extend_attn and query_store and query_store.mode == 'inject':
query = query_store.inject_query(query, self.place_in_unet, self.attnstore.curr_iter)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query).contiguous()
if perform_extend_attn:
# Anchor Caching
if anchors_cache and anchors_cache.is_cache_mode():
if self.place_in_unet not in anchors_cache.input_h_cache:
anchors_cache.input_h_cache[self.place_in_unet] = {}
# Hidden states inside the mask, for uncond (index 0) and cond (index 1) prompts
subjects_hidden_states = torch.stack([x[self.attnstore.last_mask_dropout[width]] for x in hidden_states.chunk(2)])
anchors_cache.input_h_cache[self.place_in_unet][self.attnstore.curr_iter] = subjects_hidden_states
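            # Inject mode: re-project the cached anchor subject tokens through this layer's
            # to_k/to_v and append them to the keys/values of each prompt half, so
            # non-anchor images can attend to the anchor subjects.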
if anchors_cache and anchors_cache.is_inject_mode():
# We make extended key and value by concatenating the original key and value with the query.
anchors_hidden_states = anchors_cache.input_h_cache[self.place_in_unet][self.attnstore.curr_iter]
anchors_keys = attn.to_k(anchors_hidden_states, *args)
anchors_values = attn.to_v(anchors_hidden_states, *args)
extended_key = torch.cat([torch.cat([key.chunk(2, dim=0)[x], anchors_keys[x].unsqueeze(0)], dim=1) for x in range(2)])
extended_value = torch.cat([torch.cat([value.chunk(2, dim=0)[x], anchors_values[x].unsqueeze(0)], dim=1) for x in range(2)])
extended_key = attn.head_to_batch_dim(extended_key).contiguous()
extended_value = attn.head_to_batch_dim(extended_value).contiguous()
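                # All tensors here are [batch * heads, tokens, head_dim]; torch SDPA treats
                # the leading dimension as batch, matching the commented-out xFormers call.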
# attn_masks needs to be of shape [batch_size, query_tokens, key_tokens]
# hidden_states = xformers.ops.memory_efficient_attention(query, extended_key, extended_value, op=self.attention_op, scale=attn.scale)
hidden_states = F.scaled_dot_product_attention(query, extended_key, extended_value, scale=attn.scale)
else:
# # We make extended key and value by concatenating the original key and value with the query.
# attention_mask_bias = self.attnstore.get_attn_mask_bias(tgt_size = width, bsz = batch_size)
# if attention_mask_bias is not None:
# attention_mask_bias = torch.cat([x.unsqueeze(0).expand(attn.heads, -1, -1) for x in attention_mask_bias])
# Pre-allocate the output tensor
ex_out = torch.empty_like(query)
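                # Per-image loop: image i attends to the masked union of key/value tokens
                # from all images in its classifier-free-guidance half (first half uncond,
                # second half cond), selected by the per-instance extended attention mask.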
for i in range(batch_size):
start_idx = i * attn.heads
end_idx = start_idx + attn.heads
                attention_mask = self.attnstore.get_extended_attn_mask_instance(width, i % (batch_size // 2))
                curr_q = query[start_idx:end_idx]
                if i < batch_size // 2:
                    curr_k = key[:batch_size // 2]
                    curr_v = value[:batch_size // 2]
                else:
                    curr_k = key[batch_size // 2:]
                    curr_v = value[batch_size // 2:]
                curr_k = curr_k.flatten(0, 1)[attention_mask].unsqueeze(0)
                curr_v = curr_v.flatten(0, 1)[attention_mask].unsqueeze(0)
curr_k = attn.head_to_batch_dim(curr_k).contiguous()
curr_v = attn.head_to_batch_dim(curr_v).contiguous()
# hidden_states = xformers.ops.memory_efficient_attention(curr_q, curr_k, curr_v, op=self.attention_op, scale=attn.scale)
hidden_states = F.scaled_dot_product_attention(curr_q, curr_k, curr_v, scale=attn.scale)
ex_out[start_idx:end_idx] = hidden_states
hidden_states = ex_out
else:
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
# attn_masks needs to be of shape [batch_size, query_tokens, key_tokens]
# hidden_states = xformers.ops.memory_efficient_attention(query, key, value, op=self.attention_op, scale=attn.scale)
hidden_states = F.scaled_dot_product_attention(query, key, value, scale=attn.scale)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if feature_injector is not None:
output_res = int(hidden_states.shape[1] ** 0.5)
if anchors_cache and anchors_cache.is_inject_mode():
hidden_states[batch_size//2:] = feature_injector.inject_anchors(hidden_states[batch_size//2:], self.attnstore.curr_iter, output_res, self.attnstore.extended_mapping, self.place_in_unet, anchors_cache)
else:
hidden_states[batch_size//2:] = feature_injector.inject_outputs(hidden_states[batch_size//2:], self.attnstore.curr_iter, output_res, self.attnstore.extended_mapping, self.place_in_unet, anchors_cache)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def register_extended_self_attn(unet, attnstore, extended_attn_kwargs):
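    # Latent resolution at which each attention layer operates (64 or 32, consistent with
    # an SDXL UNet at 1024px). Kept as a reference map of the layer layout; it is not
    # referenced below.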
DICT_PLACE_TO_RES = {'down_0': 64, 'down_1': 64, 'down_2': 64, 'down_3': 64, 'down_4': 64, 'down_5': 64, 'down_6': 64, 'down_7': 64,
'down_8': 32, 'down_9': 32, 'down_10': 32, 'down_11': 32, 'down_12': 32, 'down_13': 32, 'down_14': 32, 'down_15': 32,
'down_16': 32, 'down_17': 32, 'down_18': 32, 'down_19': 32, 'down_20': 32, 'down_21': 32, 'down_22': 32, 'down_23': 32,
'down_24': 32, 'down_25': 32, 'down_26': 32, 'down_27': 32, 'down_28': 32, 'down_29': 32, 'down_30': 32, 'down_31': 32,
'down_32': 32, 'down_33': 32, 'down_34': 32, 'down_35': 32, 'down_36': 32, 'down_37': 32, 'down_38': 32, 'down_39': 32,
'down_40': 32, 'down_41': 32, 'down_42': 32, 'down_43': 32, 'down_44': 32, 'down_45': 32, 'down_46': 32, 'down_47': 32,
'mid_120': 32, 'mid_121': 32, 'mid_122': 32, 'mid_123': 32, 'mid_124': 32, 'mid_125': 32, 'mid_126': 32, 'mid_127': 32,
'mid_128': 32, 'mid_129': 32, 'mid_130': 32, 'mid_131': 32, 'mid_132': 32, 'mid_133': 32, 'mid_134': 32, 'mid_135': 32,
'mid_136': 32, 'mid_137': 32, 'mid_138': 32, 'mid_139': 32, 'up_49': 32, 'up_51': 32, 'up_53': 32, 'up_55': 32, 'up_57': 32,
'up_59': 32, 'up_61': 32, 'up_63': 32, 'up_65': 32, 'up_67': 32, 'up_69': 32, 'up_71': 32, 'up_73': 32, 'up_75': 32,
'up_77': 32, 'up_79': 32, 'up_81': 32, 'up_83': 32, 'up_85': 32, 'up_87': 32, 'up_89': 32, 'up_91': 32, 'up_93': 32,
'up_95': 32, 'up_97': 32, 'up_99': 32, 'up_101': 32, 'up_103': 32, 'up_105': 32, 'up_107': 32, 'up_109': 64, 'up_111': 64,
'up_113': 64, 'up_115': 64, 'up_117': 64, 'up_119': 64}
attn_procs = {}
for i, name in enumerate(unet.attn_processors.keys()):
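        # This assumes diffusers yields attention processors in (attn1, attn2) pairs per
        # transformer block: even indices are self-attention, odd are cross-attention.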
is_self_attn = i % 2 == 0
if name.startswith("mid_block"):
place_in_unet = f"mid_{i}"
elif name.startswith("up_blocks"):
place_in_unet = f"up_{i}"
elif name.startswith("down_blocks"):
place_in_unet = f"down_{i}"
else:
continue
if is_self_attn:
attn_procs[name] = ConsistoryExtendedAttnXFormersAttnProcessor(place_in_unet, attnstore, extended_attn_kwargs)
else:
attn_procs[name] = ConsistoryAttnStoreProcessor(attnstore, place_in_unet)
unet.set_attn_processor(attn_procs)
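
# Minimal usage sketch (assuming `pipe` is an SDXL pipeline and `attnstore` implements
# the interface used above: callable, with curr_iter / last_mask_dropout /
# get_extended_attn_mask_instance / extended_mapping):
#   register_extended_self_attn(pipe.unet, attnstore, extended_attn_kwargs={
#       't_range': [(0, 20)],            # denoising-step ranges with extended attention
#       'extend_kv_unet_parts': ['up'],  # restrict K/V extension to the up blocks
#   })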