# Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00).
# Commit message: "add kandinsky2.1"
#   Co-authored-by: yiyixuxu <yixu310@gmail.com>
#   Co-authored-by: Ayush Mangal <43698245+ayushtues@users.noreply.github.com>
#   Co-authored-by: ayushmangal <ayushmangal@microsoft.com>
#   Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
#   Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
# File: 1317 lines, 55 KiB, Python.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Callable, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn

from ..utils import deprecate, logging, maybe_allow_in_graph
from ..utils.import_utils import is_xformers_available


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# xformers is an optional dependency. When it is not installed the module-level name is bound to None so that
# this module still imports; code that needs it (e.g. `Attention.set_use_memory_efficient_attention_xformers`)
# checks `is_xformers_available()` before touching `xformers.ops`.
if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


@maybe_allow_in_graph
class Attention(nn.Module):
    r"""
    A cross attention layer.

    The layer owns the projections (`to_q`, `to_k`, `to_v`, `to_out` and, optionally, `add_k_proj` / `add_v_proj`)
    and the optional norms; the attention computation itself is delegated to `self.processor`.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Upcast query and key to float32 before computing scores in `get_attention_scores`.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Upcast the attention scores to float32 before the softmax in `get_attention_scores`.
        cross_attention_norm (`str`, *optional*):
            Norm applied to `encoder_hidden_states`: `None`, `"layer_norm"` or `"group_norm"`.
        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
            Number of groups used when `cross_attention_norm="group_norm"`.
        added_kv_proj_dim (`int`, *optional*):
            Input channels of the extra `add_k_proj` / `add_v_proj` projections (used by the `AttnAddedKV*`
            processors).
        norm_num_groups (`int`, *optional*):
            If given, `self.group_norm` is a `GroupNorm` over `query_dim` channels with this many groups.
        spatial_norm_dim (`int`, *optional*):
            If given, `self.spatial_norm` is a `SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)`.
        out_bias (`bool`, *optional*, defaults to True): Whether the output projection has a bias.
        scale_qk (`bool`, *optional*, defaults to True):
            Scale attention logits by `dim_head**-0.5`; when False the scale is 1.0.
        only_cross_attention (`bool`, *optional*, defaults to False):
            Do not create `to_k` / `to_v`. Requires `added_kv_proj_dim`.
        eps (`float`, *optional*, defaults to 1e-5): Epsilon of `self.group_norm`.
        rescale_output_factor (`float`, *optional*, defaults to 1.0):
            Divisor applied to the output by processors that support it (e.g. `AttnProcessor`).
        residual_connection (`bool`, *optional*, defaults to False):
            Whether processors that support it (e.g. `AttnProcessor`) add the input back to the output.
        processor (`AttnProcessor`, *optional*):
            Attention processor to use. Defaults to `AttnProcessor2_0` when
            `F.scaled_dot_product_attention` exists and `scale_qk` is True, otherwise `AttnProcessor`.

    Raises:
        ValueError: if `only_cross_attention` is set without `added_kv_proj_dim`, or if
            `cross_attention_norm` is not one of the supported values.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block=False,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection

        # we make use of this private variable to know whether this class is loaded
        # with an deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        self.heads = heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
        else:
            self.spatial_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

        if not self.only_cross_attention:
            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        else:
            # only relevant for the `AddedKVProcessor` classes: keys/values come solely from `add_k_proj`/`add_v_proj`
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)

        # to_out[0] is the output projection, to_out[1] the dropout; processors index them positionally.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
        if processor is None:
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )
        self.set_processor(processor)

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """
        Switch this layer to (or away from) xformers memory-efficient attention.

        LoRA and Custom Diffusion processors are replaced by their xformers / non-xformers counterparts and their
        weights are copied over with `load_state_dict`; any other processor is replaced by the default one.

        Args:
            use_memory_efficient_attention_xformers (`bool`): True to enable xformers attention, False to disable.
            attention_op (`Callable`, *optional*): Operator forwarded to the xformers processors.

        Raises:
            NotImplementedError: when enabling while `added_kv_proj_dim` is set.
            ModuleNotFoundError: when enabling without xformers installed.
            ValueError: when enabling without CUDA.
            Exception: whatever the xformers smoke test raises, propagated unchanged.
        """
        is_lora = hasattr(self, "processor") and isinstance(
            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
        )
        is_custom_diffusion = hasattr(self, "processor") and isinstance(
            self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
        )

        if use_memory_efficient_attention_xformers:
            if self.added_kv_proj_dim is not None:
                # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                raise NotImplementedError(
                    "Memory efficient attention with `xformers` is currently not supported when"
                    " `self.added_kv_proj_dim` is defined."
                )
            elif not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            elif hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
                # NOTE(review): despite this message, the code below still installs an xformers processor; only the
                # xformers smoke test in the `else` branch is skipped. Confirm which behavior is intended.
                warnings.warn(
                    "You have specified using flash attention using xFormers but you have PyTorch 2.0 already installed. "
                    "We will default to PyTorch's native efficient flash attention implementation (`F.scaled_dot_product_attention`) "
                    "introduced in PyTorch 2.0. In case you are using LoRA or Custom Diffusion, we will fall "
                    "back to their respective attention processors i.e., we will NOT use the PyTorch 2.0 "
                    "native efficient flash attention."
                )
            else:
                # Make sure we can run the memory efficient attention; any failure propagates to the caller.
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )

            if is_lora:
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionXFormersAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                processor = LoRAAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            elif is_custom_diffusion:
                processor = CustomDiffusionAttnProcessor(
                    train_kv=self.processor.train_kv,
                    train_q_out=self.processor.train_q_out,
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                )
                processor.load_state_dict(self.processor.state_dict())
                if hasattr(self.processor, "to_k_custom_diffusion"):
                    processor.to(self.processor.to_k_custom_diffusion.weight.device)
            else:
                # set attention processor
                # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
                # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
                # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
                processor = (
                    AttnProcessor2_0()
                    if hasattr(F, "scaled_dot_product_attention") and self.scale_qk
                    else AttnProcessor()
                )

        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        """
        Select a sliced attention processor, or restore the default one.

        Args:
            slice_size (`int` or `None`): Slice size for `SlicedAttnProcessor` / `SlicedAttnAddedKVProcessor`.
                Must not exceed `self.sliceable_head_dim`. `None` selects a non-sliced processor.

        Raises:
            ValueError: if `slice_size` is larger than `self.sliceable_head_dim`.
        """
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            # set attention processor
            # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
            # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
            # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
            )

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        """
        Install `processor` as this layer's attention processor.

        If the current processor is an `nn.Module` (for example one holding LoRA weights) and the new one is not,
        the old one is unregistered from `self._modules` first, and an info message is logged.
        """
        # `nn.Module.__setattr__` refuses to overwrite a registered child module with a non-Module value,
        # so the old module-processor must be popped before the plain assignment below.
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        """Run `self.processor` on the given inputs and return its output."""
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        """Merge heads back into channels: `(batch * heads, seq, dim)` -> `(batch, seq, heads * dim)`."""
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, out_dim=3):
        """
        Split channels into heads.

        `(batch, seq, dim)` -> `(batch * heads, seq, dim // heads)` when `out_dim == 3`; for any other `out_dim`
        the 4-D tensor `(batch, heads, seq, dim // heads)` is returned.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        """
        Return `softmax(scale * query @ key^T + attention_mask)` in the dtype of `query`.

        `query` and `key` are 3-D (as produced by `head_to_batch_dim(..., out_dim=3)`); `attention_mask`, if
        given, is added to the scaled logits.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            # With beta=0 `baddbmm` ignores its input, so an uninitialized buffer of the right shape suffices.
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )
        del baddbmm_input

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
        """
        Pad and repeat `attention_mask` so it can be used per head.

        Args:
            attention_mask (`torch.Tensor` or `None`): The mask; `None` is returned unchanged.
            target_length (`int`): Expected key length. If the mask's last dim differs, the mask is right-padded
                with zeros by `target_length` (see the TODO below).
            batch_size (`int`, *optional*): Batch size; not passing it is deprecated and treated as 1.
            out_dim (`int`, defaults to 3): 3 repeats the mask per head along dim 0 (only if it has fewer than
                `batch_size * heads` rows); 4 inserts a head axis at dim 1 and repeats along it.
        """
        if batch_size is None:
            deprecate(
                "batch_size=None",
                "0.0.15",
                (
                    "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
                    " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
                    " `prepare_attention_mask` when preparing the attention_mask."
                ),
            )
            batch_size = 1

        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                #       we want to instead pad by (0, remaining_length), where remaining_length is:
                #       remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        """
        Apply `self.norm_cross` to `encoder_hidden_states` along the last (hidden) dimension.

        Must only be called when `self.norm_cross` is set (asserted).
        """
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states


class AttnProcessor:
    r"""
    Default attention processor: explicit score matrix via `Attention.get_attention_scores`, then `torch.bmm`.

    Supports 3-D `(batch, seq, channels)` and 4-D `(batch, channels, height, width)` inputs, the layer's optional
    spatial norm, group norm and cross-attention norm, and its residual / output-rescaling settings.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_image_input = hidden_states.ndim == 4
        if is_image_input:
            batch_size, channel, height, width = hidden_states.shape
            # (B, C, H, W) -> (B, H*W, C)
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        # The mask must match the key length, i.e. the length of whatever the keys are computed from.
        kv_source_shape = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        batch_size, sequence_length, _ = kv_source_shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query, key, value = (attn.head_to_batch_dim(projection) for projection in (query, key, value))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        hidden_states = attn.to_out[0](hidden_states)  # linear projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        if is_image_input:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        return hidden_states / attn.rescale_output_factor


class LoRALinearLayer(nn.Module):
    r"""
    Low-rank linear map `x -> up(down(x))` with inner width `rank`, used as a LoRA adapter.

    `down` is initialized from a normal distribution with std `1 / rank` and `up` with zeros, so a freshly built
    layer outputs zeros. Inputs are cast to the adapter's weight dtype and the result is cast back to the input
    dtype.

    Args:
        in_features (`int`): Input channels.
        out_features (`int`): Output channels.
        rank (`int`, defaults to 4): Inner width; must not exceed `min(in_features, out_features)`.

    Raises:
        ValueError: if `rank` is too large.
    """

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        max_rank = min(in_features, out_features)
        if rank > max_rank:
            raise ValueError(f"LoRA rank {rank} must be less or equal than {max_rank}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        projected = self.up(self.down(hidden_states.to(self.down.weight.dtype)))
        return projected.to(input_dtype)


class LoRAAttnProcessor(nn.Module):
    r"""
    Attention processor that adds LoRA updates to the query, key, value and output projections of an `Attention`
    layer; otherwise it computes attention like `AttnProcessor`.

    Args:
        hidden_size (`int`): Channels of the attention layer (size of the query/output LoRA layers).
        cross_attention_dim (`int`, *optional*): Channels of `encoder_hidden_states`; falls back to `hidden_size`.
        rank (`int`, defaults to 4): Rank of every LoRA layer.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
    ):
        """
        Compute attention with LoRA-augmented projections.

        Args:
            attn (`Attention`): The layer providing the base projections, norms and helpers.
            hidden_states: Query input, 3-D or 4-D (4-D inputs are flattened over the last two dims).
            encoder_hidden_states (*optional*): Key/value input; defaults to `hidden_states`.
            attention_mask (*optional*): Additive mask, prepared with `attn.prepare_attention_mask`.
            scale (`float`, defaults to 1.0): Multiplier on every LoRA contribution; 0.0 gives the base layer.
            temb (*optional*): Conditioning for `attn.spatial_norm`, used only when the layer defines one
                (mirrors `AttnProcessor`).
        """
        residual = hidden_states

        # Match `AttnProcessor`: layers built with `spatial_norm_dim` must have their spatial norm applied.
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class CustomDiffusionAttnProcessor(nn.Module):
    r"""
    Attention processor for Custom Diffusion fine-tuning.

    Depending on `train_kv` / `train_q_out`, the key/value and/or query/output projections of the wrapped
    `Attention` layer are replaced by this processor's own `*_custom_diffusion` linear layers.

    Args:
        train_kv (`bool`, defaults to True): Use this processor's own key/value projections.
        train_q_out (`bool`, defaults to True): Use this processor's own query and output projections.
        hidden_size (`int`, *optional*): Channels of the attention layer.
        cross_attention_dim (`int`, *optional*): Channels of `encoder_hidden_states`; falls back to `hidden_size`.
        out_bias (`bool`, defaults to True): Whether the custom output projection has a bias.
        dropout (`float`, defaults to 0.0): Dropout applied after the custom output projection.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=True,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim

        # The `_custom_diffusion` suffix identifies these weights for serialization and loading.
        kv_in_features = cross_attention_dim or hidden_size
        if self.train_kv:
            self.to_k_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(kv_in_features, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList(
                [nn.Linear(hidden_size, hidden_size, bias=out_bias), nn.Dropout(dropout)]
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query_proj = self.to_q_custom_diffusion if self.train_q_out else attn.to_q
        query = query_proj(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if not is_cross_attention:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key_proj, value_proj = self.to_k_custom_diffusion, self.to_v_custom_diffusion
        else:
            key_proj, value_proj = attn.to_k, attn.to_v
        key = key_proj(encoder_hidden_states)
        value = value_proj(encoder_hidden_states)

        if is_cross_attention:
            # Stop gradients through index 0 of dim 1 (presumably the first token of (batch, seq, dim) keys);
            # every other position keeps its gradient.
            grad_mask = torch.ones_like(key)
            grad_mask[:, :1, :] = grad_mask[:, :1, :] * 0.0
            key = grad_mask * key + (1 - grad_mask) * key.detach()
            value = grad_mask * value + (1 - grad_mask) * value.detach()

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        output_layers = self.to_out_custom_diffusion if self.train_q_out else attn.to_out
        hidden_states = output_layers[0](hidden_states)  # linear projection
        hidden_states = output_layers[1](hidden_states)  # dropout

        return hidden_states


class AttnAddedKVProcessor:
    r"""
    Processor for `Attention` layers with `added_kv_proj_dim`.

    Keys/values from `encoder_hidden_states` (via `add_k_proj` / `add_v_proj`) are prepended to the self-attention
    keys/values, unless `attn.only_cross_attention` is set, in which case only the added ones are used. The input
    is always added back as a residual.
    """

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        # Flatten trailing (spatial) dims and move channels last: (B, C, ...) -> (B, N, C).
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # NOTE: without encoder states, the *un-normalized* hidden states feed the added key/value projections.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # `group_norm` only exists when the layer was built with `norm_num_groups`; skip it otherwise
        # (as `AttnProcessor` does) instead of calling `None`.
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            # Added (encoder) keys/values come first along the sequence axis.
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


class AttnAddedKVProcessor2_0:
    r"""
    PyTorch 2.0 variant of `AttnAddedKVProcessor` that uses `F.scaled_dot_product_attention`.

    Note that SDPA applies its default `1 / sqrt(head_dim)` scaling, not `attn.scale` (see TODO below).

    Raises:
        ImportError: on construction, if `F.scaled_dot_product_attention` is unavailable.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        # Flatten trailing (spatial) dims and move channels last: (B, C, ...) -> (B, N, C).
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        # NOTE: without encoder states, the *un-normalized* hidden states feed the added key/value projections.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # `group_norm` only exists when the layer was built with `norm_num_groups`; skip it otherwise.
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query, out_dim=4)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key, out_dim=4)
            value = attn.head_to_batch_dim(value, out_dim=4)
            # Added (encoder) keys/values come first along the sequence axis (dim 2 in the 4-D head layout).
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # Merge heads: the merged width is heads * head_dim (the input width of `to_out[0]`), which need not equal
        # the residual's channel count.
        head_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states


class LoRAAttnAddedKVProcessor(nn.Module):
    r"""
    Added-key/value attention processor with low-rank (LoRA) updates on every projection.

    The base projections of `attn` (`to_q`, `add_k_proj`, `add_v_proj`, `to_k`, `to_v`, `to_out[0]`) are each
    augmented with `scale * lora(x)`.

    Args:
        hidden_size (`int`): Width of the query, self key/value and output projections.
        cross_attention_dim (`int`, *optional*): Width of `encoder_hidden_states`; `None` means `hidden_size`.
        rank (`int`, defaults to 4): Rank of each LoRA update.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank

        context_dim = cross_attention_dim or hidden_size

        # Construction order is kept fixed: it determines parameter initialization order and state-dict layout.
        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.add_k_proj_lora = LoRALinearLayer(context_dim, hidden_size, rank)
        self.add_v_proj_lora = LoRALinearLayer(context_dim, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        """
        Run added-KV attention with LoRA-adjusted projections.

        Args:
            attn: The `Attention` module providing projections, norms and helpers.
            hidden_states: Input whose first two dims are (batch, channels); trailing dims are flattened into tokens.
            encoder_hidden_states: Optional context; defaults to the (pre-group-norm) flattened `hidden_states`.
            attention_mask: Optional mask, prepared via `attn.prepare_attention_mask`.
            scale: Multiplier applied to every LoRA branch.

        Returns:
            Tensor with the same shape as `hidden_states`, including the residual.
        """

        def lora_proj(base, lora, x):
            # Base projection plus the scaled low-rank update.
            return base(x) + scale * lora(x)

        residual = hidden_states
        tokens = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = tokens.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # The context is captured before group norm is applied to the query-side tokens.
        if encoder_hidden_states is None:
            context = tokens
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        tokens = attn.group_norm(tokens.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(lora_proj(attn.to_q, self.to_q_lora, tokens))

        added_key = attn.head_to_batch_dim(lora_proj(attn.add_k_proj, self.add_k_proj_lora, context))
        added_value = attn.head_to_batch_dim(lora_proj(attn.add_v_proj, self.add_v_proj_lora, context))

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            # Self-attention keys/values are appended after the added (context) keys/values.
            self_key = attn.head_to_batch_dim(lora_proj(attn.to_k, self.to_k_lora, tokens))
            self_value = attn.head_to_batch_dim(lora_proj(attn.to_v, self.to_v_lora, tokens))
            key = torch.cat([added_key, self_key], dim=1)
            value = torch.cat([added_value, self_value], dim=1)

        probs = attn.get_attention_scores(query, key, attention_mask)
        out = attn.batch_to_head_dim(torch.bmm(probs, value))

        # Output projection (with its LoRA branch), then dropout.
        out = lora_proj(attn.to_out[0], self.to_out_lora, out)
        out = attn.to_out[1](out)

        # Restore the input layout and add the residual.
        return out.transpose(-1, -2).reshape(residual.shape) + residual
|
|
|
|
|
|
class XFormersAttnProcessor:
    r"""
    Attention processor that computes attention with `xformers.ops.memory_efficient_attention`.

    Args:
        attention_op (`Callable`, *optional*, defaults to `None`):
            Forwarded unchanged as the `op` argument of `xformers.ops.memory_efficient_attention`.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
    ):
        """
        Args:
            attn: The `Attention` module whose projections, norms and flags are used.
            hidden_states: Query-side input, `(batch, seq, channels)` or `(batch, channels, height, width)`.
            encoder_hidden_states: Optional key/value-side input; defaults to `hidden_states` (self-attention).
            attention_mask: Optional mask, prepared via `attn.prepare_attention_mask` and added as a bias.
            temb: Optional conditioning tensor passed to `attn.spatial_norm` when that module is set. Accepting
                it keeps this processor interchangeable with `AttnProcessor2_0` on spatially-normed layers.

        Returns:
            Tensor in the same layout (3D or 4D) as the input `hidden_states`.
        """
        residual = hidden_states

        # Spatial conditioning, applied exactly as in AttnProcessor2_0.
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, key_tokens, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
        if attention_mask is not None:
            # expand our mask's singleton query_tokens dimension:
            #   [batch*heads,            1, key_tokens] ->
            #   [batch*heads, query_tokens, key_tokens]
            # so that it can be added as a bias onto the attention scores that xformers computes:
            #   [batch*heads, query_tokens, key_tokens]
            # we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
class AttnProcessor2_0:
    r"""
    Attention processor built on PyTorch 2.0's `torch.nn.functional.scaled_dot_product_attention`.

    Raises:
        ImportError: At construction time, if `F.scaled_dot_product_attention` is not available.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        """
        Args:
            attn: The `Attention` module whose projections, norms and flags are used.
            hidden_states: Query-side input, `(batch, seq, channels)` or `(batch, channels, height, width)`.
            encoder_hidden_states: Optional key/value-side input; defaults to `hidden_states` (self-attention).
            attention_mask: Optional mask, prepared via `attn.prepare_attention_mask` and reshaped per head.
            temb: Optional conditioning tensor passed to `attn.spatial_norm` when that module is set.

        Returns:
            Tensor in the same layout (3D or 4D) as the input `hidden_states`.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # Derive the per-head size from the projected width (heads * dim_head), not from the input channel
        # count: the two differ whenever query_dim != heads * dim_head, and using the input width would
        # mis-split (or fail to split) the heads below.
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
class LoRAXFormersAttnProcessor(nn.Module):
    r"""
    xFormers memory-efficient attention with low-rank (LoRA) updates on the q/k/v/out projections.

    Args:
        hidden_size (`int`): Width of the query and output projections.
        cross_attention_dim (`int` or `None`): Width of `encoder_hidden_states`; `None` means `hidden_size`.
        rank (`int`, defaults to 4): Rank of each LoRA update.
        attention_op (`Callable`, *optional*): Forwarded as `op` to `xformers.ops.memory_efficient_attention`.
    """

    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.rank = rank
        self.attention_op = attention_op

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        """
        Args:
            attn: The `Attention` module whose projections, norms and flags are used.
            hidden_states: Query-side input, `(batch, seq, channels)` or `(batch, channels, height, width)`.
            encoder_hidden_states: Optional key/value-side input; defaults to `hidden_states`.
            attention_mask: Optional mask, prepared via `attn.prepare_attention_mask` and used as an xformers bias.
            scale: Multiplier applied to every LoRA branch.

        Returns:
            Tensor in the same layout (3D or 4D) as the input `hidden_states`.
        """
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # xformers does not broadcast the mask's singleton query dimension
            # ([batch*heads, 1, key_tokens]), so expand it explicitly to
            # [batch*heads, query_tokens, key_tokens], as XFormersAttnProcessor does.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        # Cast back to the query dtype, consistent with the other xformers processors.
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
class CustomDiffusionXFormersAttnProcessor(nn.Module):
    r"""
    Custom Diffusion attention processor computed with xFormers memory-efficient attention.

    Args:
        train_kv (`bool`, defaults to `True`):
            Use this processor's own `to_k_custom_diffusion` / `to_v_custom_diffusion` instead of `attn.to_k` /
            `attn.to_v`.
        train_q_out (`bool`, defaults to `False`):
            Use this processor's own query projection and output projection/dropout.
        hidden_size (`int`, *optional*): Width of the attention layer's hidden states.
        cross_attention_dim (`int`, *optional*): Width of `encoder_hidden_states`; falls back to `hidden_size`.
        out_bias (`bool`, defaults to `True`): Whether the custom output projection has a bias.
        dropout (`float`, defaults to 0.0): Dropout probability after the custom output projection.
        attention_op (`Callable`, *optional*): Forwarded as `op` to `xformers.ops.memory_efficient_attention`.
    """

    def __init__(
        self,
        train_kv=True,
        train_q_out=False,
        hidden_size=None,
        cross_attention_dim=None,
        out_bias=True,
        dropout=0.0,
        attention_op: Optional[Callable] = None,
    ):
        super().__init__()
        self.train_kv = train_kv
        self.train_q_out = train_q_out

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.attention_op = attention_op

        # `_custom_diffusion` id for easy serialization and loading.
        if self.train_kv:
            self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
            self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        if self.train_q_out:
            self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
            self.to_out_custom_diffusion = nn.ModuleList([])
            self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
            self.to_out_custom_diffusion.append(nn.Dropout(dropout))

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Args:
            attn: The `Attention` module whose projections and helpers are used when not overridden.
            hidden_states: Query-side input, unpacked as `(batch, seq, channels)`.
            encoder_hidden_states: Optional cross-attention context; `None` means self-attention.
            attention_mask: Optional mask, prepared via `attn.prepare_attention_mask` and used as an xformers bias.

        Returns:
            The attention output after the (custom or original) output projection and dropout.
        """
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attention_mask is not None:
            # xformers does not broadcast the mask's singleton query dimension
            # ([batch*heads, 1, key_tokens]), so expand it explicitly to
            # [batch*heads, query_tokens, key_tokens], as XFormersAttnProcessor does.
            _, query_tokens, _ = hidden_states.shape
            attention_mask = attention_mask.expand(-1, query_tokens, -1)

        if self.train_q_out:
            query = self.to_q_custom_diffusion(hidden_states)
        else:
            query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            crossattn = False
            encoder_hidden_states = hidden_states
        else:
            crossattn = True
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if self.train_kv:
            key = self.to_k_custom_diffusion(encoder_hidden_states)
            value = self.to_v_custom_diffusion(encoder_hidden_states)
        else:
            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

        if crossattn:
            # Keep the values of the first context token but block its gradient:
            # detach is 0 at token index 0 and 1 elsewhere.
            detach = torch.ones_like(key)
            detach[:, :1, :] = detach[:, :1, :] * 0.0
            key = detach * key + (1 - detach) * key.detach()
            value = detach * value + (1 - detach) * value.detach()

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if self.train_q_out:
            # linear proj
            hidden_states = self.to_out_custom_diffusion[0](hidden_states)
            # dropout
            hidden_states = self.to_out_custom_diffusion[1](hidden_states)
        else:
            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
|
|
|
|
|
|
class SlicedAttnProcessor:
    r"""
    Attention processor that computes attention in slices along the combined `batch * heads` dimension,
    trading speed for lower peak memory.

    Args:
        slice_size (`int`):
            Number of `batch * heads` rows attended per step. When the row count is not a multiple of
            `slice_size`, the last slice is shorter.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Args:
            attn: The `Attention` module whose projections, norms and helpers are used.
            hidden_states: Query-side input, `(batch, seq, channels)` or `(batch, channels, height, width)`.
            encoder_hidden_states: Optional key/value-side input; defaults to `hidden_states`.
            attention_mask: Optional mask, prepared via `attn.prepare_attention_mask` and sliced with the rows.

        Returns:
            Tensor in the same layout (3D or 4D) as the input `hidden_states`.
        """
        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Visit every slice start, including a final partial slice. Floor-dividing the row count by
        # slice_size would skip the tail rows (leaving them as zeros) whenever it is not an exact multiple,
        # and would skip everything when slice_size exceeds the row count.
        for start_idx in range(0, batch_size_attention, self.slice_size):
            end_idx = start_idx + self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
|
|
|
|
|
|
class SlicedAttnAddedKVProcessor:
    r"""
    Added-key/value attention processor that computes attention in slices along the combined
    `batch * heads` dimension, trading speed for lower peak memory.

    Args:
        slice_size (`int`):
            Number of `batch * heads` rows attended per step. When the row count is not a multiple of
            `slice_size`, the last slice is shorter.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        """
        Args:
            attn: The `Attention` module whose projections, norms and helpers are used.
            hidden_states: Input whose first two dims are (batch, channels); trailing dims are flattened into tokens.
            encoder_hidden_states: Optional context; defaults to the (pre-group-norm) flattened `hidden_states`.
            attention_mask: Optional mask, prepared via `attn.prepare_attention_mask` and sliced with the rows.

        Returns:
            Tensor with the same shape as `hidden_states`, including the residual.
        """
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            # Self-attention keys/values are appended after the added (context) keys/values.
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Visit every slice start, including a final partial slice. Floor-dividing the row count by
        # slice_size would skip the tail rows (leaving them as zeros) whenever it is not an exact multiple,
        # and would skip everything when slice_size exceeds the row count.
        for start_idx in range(0, batch_size_attention, self.slice_size):
            end_idx = start_idx + self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
|
|
|
|
|
|
# Type alias: a `Union` of the attention processor classes listed below (some of them are defined
# earlier in this file, outside this excerpt). The member order is left unchanged on purpose because
# it determines `AttentionProcessor.__args__`.
AttentionProcessor = Union[
    AttnProcessor,
    AttnProcessor2_0,
    XFormersAttnProcessor,
    SlicedAttnProcessor,
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    LoRAAttnProcessor,
    LoRAXFormersAttnProcessor,
    LoRAAttnAddedKVProcessor,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
]
|
|
|
|
|
|
class SpatialNorm(nn.Module):
    """
    Spatially conditioned normalization (https://arxiv.org/abs/2209.09002).

    The feature map `f` is group-normalized (32 groups) and then modulated by a per-pixel scale and shift.
    Both are predicted from the conditioning map `zq` by 1x1 convolutions, after `zq` has been resized
    (nearest-neighbour) to the spatial size of `f`.

    Args:
        f_channels (`int`): Channels of the normalized feature map; must be divisible by 32 for `GroupNorm`.
        zq_channels (`int`): Channels of the conditioning map.
    """

    def __init__(
        self,
        f_channels,
        zq_channels,
    ):
        super().__init__()

        def pointwise_conv():
            # 1x1 convolution mapping the conditioning channels onto the feature channels.
            return nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)

        # Submodules are created in the same order as before (norm, scale conv, shift conv) so parameter
        # initialization and state-dict keys are unchanged.
        self.norm_layer = nn.GroupNorm(32, f_channels, eps=1e-6, affine=True)
        self.conv_y = pointwise_conv()
        self.conv_b = pointwise_conv()

    def forward(self, f, zq):
        """Return `GroupNorm(f) * conv_y(zq') + conv_b(zq')`, where `zq'` is `zq` resized to `f`'s spatial size."""
        cond = F.interpolate(zq, size=f.shape[-2:], mode="nearest")
        normalized = self.norm_layer(f)
        return normalized * self.conv_y(cond) + self.conv_b(cond)