# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Not a contribution
# Changes made by NVIDIA CORPORATION & AFFILIATES enabling ConsiStory or otherwise documented as NVIDIA-proprietary
# are not a contribution and subject to the license under the LICENSE file located at the root directory.

from typing import Callable, Optional

import torch
import torch.nn.functional as F

from diffusers.utils import USE_PEFT_BACKEND
from diffusers.models.attention_processor import Attention

from .consistory_utils import AnchorCache, FeatureInjector, QueryStore


class ConsistoryAttnStoreProcessor:
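    """Standard attention processor that additionally records attention maps.

    Runs vanilla (non-extended) attention; when `record_attention` is True it pushes
    the computed attention probabilities into the shared `attnstore` so that later
    ConsiStory stages (e.g. subject-mask extraction) can consume them.
    """
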
    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, record_attention=True, **kwargs):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        is_cross = encoder_hidden_states is not None
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # Store the attention maps only when requested (the upstream Attend-and-Excite
        # processor gated this on attention_probs.requires_grad instead).
        if record_attention:
            self.attnstore(attention_probs, is_cross, self.place_in_unet, attn.heads)

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class ConsistoryExtendedAttnXFormersAttnProcessor:
    r"""
    Processor implementing ConsiStory extended (subject-shared) self-attention.

    Based on the diffusers xFormers memory-efficient attention processor; this variant
    currently runs `torch.nn.functional.scaled_dot_product_attention` (the original
    xFormers calls are kept as comments).

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            The base
            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
            operator.
    """

    def __init__(self, place_in_unet, attnstore, extended_attn_kwargs, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op
        self.t_range = extended_attn_kwargs.get('t_range', [])
        self.extend_kv_unet_parts = extended_attn_kwargs.get('extend_kv_unet_parts', ['down', 'mid', 'up'])

        self.place_in_unet = place_in_unet
        self.curr_unet_part = self.place_in_unet.split('_')[0]
        self.attnstore = attnstore

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
        perform_extend_attn: bool = False,
        query_store: Optional[QueryStore] = None,
        feature_injector: Optional[FeatureInjector] = None,
        anchors_cache: Optional[AnchorCache] = None,
        **kwargs
    ) -> torch.FloatTensor:
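        """Run self-attention, optionally extending keys/values across the batch.

        When `perform_extend_attn` is active (self-attention, current step inside
        `t_range`, and this UNet part listed in `extend_kv_unet_parts`), each image
        attends not only to its own tokens but also to masked subject tokens from the
        other images in the batch (or from cached anchor images), encouraging a
        consistent subject across the generated images.
        """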
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        else:
            batch_size, wh, channel = hidden_states.shape
            height = width = int(wh ** 0.5)

        is_cross = encoder_hidden_states is not None
        perform_extend_attn = perform_extend_attn and (not is_cross) and \
            any(x[0] <= self.attnstore.curr_iter <= x[1] for x in self.t_range) and \
            self.curr_unet_part in self.extend_kv_unet_parts

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        if (self.curr_unet_part in self.extend_kv_unet_parts) and query_store and query_store.mode == 'cache':
            query_store.cache_query(query, self.place_in_unet)
        elif perform_extend_attn and query_store and query_store.mode == 'inject':
            query = query_store.inject_query(query, self.place_in_unet, self.attnstore.curr_iter)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query).contiguous()

        if perform_extend_attn:
            # Anchor Caching
            if anchors_cache and anchors_cache.is_cache_mode():
                if self.place_in_unet not in anchors_cache.input_h_cache:
                    anchors_cache.input_h_cache[self.place_in_unet] = {}

                # Hidden states inside the mask, for uncond (index 0) and cond (index 1) prompts
                subjects_hidden_states = torch.stack([x[self.attnstore.last_mask_dropout[width]] for x in hidden_states.chunk(2)])
                anchors_cache.input_h_cache[self.place_in_unet][self.attnstore.curr_iter] = subjects_hidden_states

            if anchors_cache and anchors_cache.is_inject_mode():
                # Extend key and value with the cached anchor hidden states for this layer and step.
                anchors_hidden_states = anchors_cache.input_h_cache[self.place_in_unet][self.attnstore.curr_iter]

                anchors_keys = attn.to_k(anchors_hidden_states, *args)
                anchors_values = attn.to_v(anchors_hidden_states, *args)

                extended_key = torch.cat([torch.cat([key.chunk(2, dim=0)[x], anchors_keys[x].unsqueeze(0)], dim=1) for x in range(2)])
                extended_value = torch.cat([torch.cat([value.chunk(2, dim=0)[x], anchors_values[x].unsqueeze(0)], dim=1) for x in range(2)])

                extended_key = attn.head_to_batch_dim(extended_key).contiguous()
                extended_value = attn.head_to_batch_dim(extended_value).contiguous()

                # attn_masks needs to be of shape [batch_size, query_tokens, key_tokens]
                # hidden_states = xformers.ops.memory_efficient_attention(query, extended_key, extended_value, op=self.attention_op, scale=attn.scale)
                hidden_states = F.scaled_dot_product_attention(query, extended_key, extended_value, scale=attn.scale)
            else:
                # # We make extended key and value by concatenating the original key and value with the query.
                # attention_mask_bias = self.attnstore.get_attn_mask_bias(tgt_size = width, bsz = batch_size)

                # if attention_mask_bias is not None:
                #     attention_mask_bias = torch.cat([x.unsqueeze(0).expand(attn.heads, -1, -1) for x in attention_mask_bias])

                # Pre-allocate the output tensor
                ex_out = torch.empty_like(query)

                for i in range(batch_size):
                    start_idx = i * attn.heads
                    end_idx = start_idx + attn.heads

                    attention_mask = self.attnstore.get_extended_attn_mask_instance(width, i % (batch_size // 2))

                    curr_q = query[start_idx:end_idx]

                    # Keep keys/values from the same CFG half as the current sample
                    # (first half: uncond prompts, second half: cond prompts).
                    if i < batch_size // 2:
                        curr_k = key[:batch_size // 2]
                        curr_v = value[:batch_size // 2]
                    else:
                        curr_k = key[batch_size // 2:]
                        curr_v = value[batch_size // 2:]

                    # Flatten all images in this half and keep only the tokens selected by the
                    # extended attention mask, so the query also attends to subject tokens from
                    # the other images in the batch.
                    curr_k = curr_k.flatten(0, 1)[attention_mask].unsqueeze(0)
                    curr_v = curr_v.flatten(0, 1)[attention_mask].unsqueeze(0)

                    curr_k = attn.head_to_batch_dim(curr_k).contiguous()
                    curr_v = attn.head_to_batch_dim(curr_v).contiguous()

                    # hidden_states = xformers.ops.memory_efficient_attention(curr_q, curr_k, curr_v, op=self.attention_op, scale=attn.scale)
                    hidden_states = F.scaled_dot_product_attention(curr_q, curr_k, curr_v, scale=attn.scale)

                    ex_out[start_idx:end_idx] = hidden_states

                hidden_states = ex_out
        else:
            key = attn.head_to_batch_dim(key).contiguous()
            value = attn.head_to_batch_dim(value).contiguous()

            # attn_masks needs to be of shape [batch_size, query_tokens, key_tokens]
            # hidden_states = xformers.ops.memory_efficient_attention(query, key, value, op=self.attention_op, scale=attn.scale)
            hidden_states = F.scaled_dot_product_attention(query, key, value, scale=attn.scale)

        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if feature_injector is not None:
            output_res = int(hidden_states.shape[1] ** 0.5)

            if anchors_cache and anchors_cache.is_inject_mode():
                hidden_states[batch_size // 2:] = feature_injector.inject_anchors(hidden_states[batch_size // 2:], self.attnstore.curr_iter, output_res, self.attnstore.extended_mapping, self.place_in_unet, anchors_cache)
            else:
                hidden_states[batch_size // 2:] = feature_injector.inject_outputs(hidden_states[batch_size // 2:], self.attnstore.curr_iter, output_res, self.attnstore.extended_mapping, self.place_in_unet, anchors_cache)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def register_extended_self_attn(unet, attnstore, extended_attn_kwargs):
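    """Replace the UNet attention processors with ConsiStory processors.

    Layers at even positions in `unet.attn_processors` (assumed to be the `attn1`
    self-attention layers) get the extended-attention processor; the remaining
    (cross-attention) layers get the attention-store processor.
    """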
    # Mapping from place_in_unet to feature-map resolution (not referenced below).
    DICT_PLACE_TO_RES = {'down_0': 64, 'down_1': 64, 'down_2': 64, 'down_3': 64, 'down_4': 64, 'down_5': 64, 'down_6': 64, 'down_7': 64,
                         'down_8': 32, 'down_9': 32, 'down_10': 32, 'down_11': 32, 'down_12': 32, 'down_13': 32, 'down_14': 32, 'down_15': 32,
                         'down_16': 32, 'down_17': 32, 'down_18': 32, 'down_19': 32, 'down_20': 32, 'down_21': 32, 'down_22': 32, 'down_23': 32,
                         'down_24': 32, 'down_25': 32, 'down_26': 32, 'down_27': 32, 'down_28': 32, 'down_29': 32, 'down_30': 32, 'down_31': 32,
                         'down_32': 32, 'down_33': 32, 'down_34': 32, 'down_35': 32, 'down_36': 32, 'down_37': 32, 'down_38': 32, 'down_39': 32,
                         'down_40': 32, 'down_41': 32, 'down_42': 32, 'down_43': 32, 'down_44': 32, 'down_45': 32, 'down_46': 32, 'down_47': 32,
                         'mid_120': 32, 'mid_121': 32, 'mid_122': 32, 'mid_123': 32, 'mid_124': 32, 'mid_125': 32, 'mid_126': 32, 'mid_127': 32,
                         'mid_128': 32, 'mid_129': 32, 'mid_130': 32, 'mid_131': 32, 'mid_132': 32, 'mid_133': 32, 'mid_134': 32, 'mid_135': 32,
                         'mid_136': 32, 'mid_137': 32, 'mid_138': 32, 'mid_139': 32, 'up_49': 32, 'up_51': 32, 'up_53': 32, 'up_55': 32, 'up_57': 32,
                         'up_59': 32, 'up_61': 32, 'up_63': 32, 'up_65': 32, 'up_67': 32, 'up_69': 32, 'up_71': 32, 'up_73': 32, 'up_75': 32,
                         'up_77': 32, 'up_79': 32, 'up_81': 32, 'up_83': 32, 'up_85': 32, 'up_87': 32, 'up_89': 32, 'up_91': 32, 'up_93': 32,
                         'up_95': 32, 'up_97': 32, 'up_99': 32, 'up_101': 32, 'up_103': 32, 'up_105': 32, 'up_107': 32, 'up_109': 64, 'up_111': 64,
                         'up_113': 64, 'up_115': 64, 'up_117': 64, 'up_119': 64}
    attn_procs = {}
    for i, name in enumerate(unet.attn_processors.keys()):
        # Assumes processors alternate attn1 (self-attention) / attn2 (cross-attention),
        # so even indices are self-attention layers.
        is_self_attn = i % 2 == 0
        if name.startswith("mid_block"):
            place_in_unet = f"mid_{i}"
        elif name.startswith("up_blocks"):
            place_in_unet = f"up_{i}"
        elif name.startswith("down_blocks"):
            place_in_unet = f"down_{i}"
        else:
            continue

        if is_self_attn:
            attn_procs[name] = ConsistoryExtendedAttnXFormersAttnProcessor(place_in_unet, attnstore, extended_attn_kwargs)
        else:
            attn_procs[name] = ConsistoryAttnStoreProcessor(attnstore, place_in_unet)

    unet.set_attn_processor(attn_procs)
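

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes an
# SDXL pipeline whose UNet these processors target, and an attention-store
# object compatible with what the processors expect (providing `curr_iter`,
# `last_mask_dropout`, `get_extended_attn_mask_instance`, `extended_mapping`);
# the names `pipe`, `attention_store`, and the kwarg values below are
# hypothetical.
#
#     attention_store = ...  # ConsiStory attention store instance (constructed elsewhere)
#     extended_attn_kwargs = {
#         't_range': [(0, 40)],            # diffusion steps where extended attention is active
#         'extend_kv_unet_parts': ['up'],  # restrict KV sharing to the UNet decoder
#     }
#     register_extended_self_attn(pipe.unet, attention_store, extended_attn_kwargs)
# ----------------------------------------------------------------------------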