Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Support for cross-attention bias / mask (#2634)

* Cross-attention masks:
  - prefer qualified symbol; fix accidental Optional; prefer qualified symbol in AttentionProcessor; prefer qualified symbol in embeddings.py; qualify symbol in transformer_2d; qualify FloatTensor in unet_2d_blocks
  - move new transformer_2d params attention_mask, encoder_attention_mask to the end of the section which is assumed (e.g. by functions such as checkpoint()) to have a stable positional param interface; regard return_dict as a special case which is assumed to be injected separately from positional params (e.g. by create_custom_forward())
  - move new encoder_attention_mask param to the end of the CrossAttn block interfaces and the UNet2DCondition interface, to maintain the positional param interface
  - regenerate modeling_text_unet.py; remove unused import
  - unet_2d_condition, versatile_diffusion/modeling_text_unet.py, and transformer_2d encoder_attention_mask docs (Co-authored-by: Pedro Cuenca <pedro@huggingface.co>)
  - unet_2d_blocks.py: add parameter name comments (Co-authored-by: Pedro Cuenca <pedro@huggingface.co>)
  - revert description: the bool-to-bias treatment happens in unet_2d_condition only
  - comment parameter names; fix copies, style
* encoder_attention_mask for SimpleCrossAttnDownBlock2D, SimpleCrossAttnUpBlock2D
* encoder_attention_mask for UNetMidBlock2DSimpleCrossAttn
* support attention_mask, encoder_attention_mask in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D, KAttentionBlock; fix binding of attention_mask, cross_attention_kwargs params in KCrossAttnDownBlock2D, KCrossAttnUpBlock2D checkpoint invocations
* fix mistake made during merge conflict resolution
* regenerate versatile_diffusion
* pass time embedding into checkpointed attention invocation
* always assume encoder_attention_mask is a mask (i.e. not a bias)
* style, fix-copies
* add tests for cross-attention masks
* add test for padding of attention mask
* explain the mask's query_tokens dim; fix the explanation about broadcasting over channels: we actually broadcast over query tokens
* support both masks and biases in Transformer2DModel#forward; document behaviour
* fix-copies
* delete attention_mask docs, on the basis that I never tested self-attention masking myself and am not comfortable explaining it, since I don't actually understand how a self-attn mask can work in its current form: the key length will be different in every ResBlock (we don't downsample the mask when we downsample the image)
* review feedback: the standard UNet blocks shouldn't pass temb to attn (only to resnet); remove it from KCrossAttnDownBlock2D, KCrossAttnUpBlock2D#forward
* remove the encoder_attention_mask param from SimpleCrossAttn{Up,Down}Block2D and UNetMidBlock2DSimpleCrossAttn, and the mask choice in those blocks' #forward, on the basis that they only do one type of attention, so the consumer can pass whichever type of attention_mask is appropriate
* put attention mask padding back to how it was (since the SD use-case it enabled wasn't important, and it breaks the original UnCLIP use-case); disable the test which was added
* fix-copies
* style
* fix-copies
* put the encoder_attention_mask param back into the Simple block forward interfaces, to ensure consistency of the forward interface
* restore passing of emb to KAttentionBlock#forward, on the basis that its removal caused test failures; also restore the passing of emb to checkpointed calls to KAttentionBlock#forward
* make the simple unet2d blocks use encoder_attention_mask, but only when attention_mask is None; this should fix UnCLIP compatibility
* fix copies
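Illustrative sketch (not part of the diff): the convention introduced here is that a 2D keep/discard mask of shape (batch, key_tokens) is turned into an additive bias with a singleton query_tokens dimension, so it can broadcast over attention scores. A minimal stand-alone version, with a helper name of our own choosing:

import torch

# A 2D mask [batch, key_tokens] (1 = keep, 0 = discard) becomes an additive bias
# [batch, 1, key_tokens] that broadcasts over scores shaped
# [batch, heads, query_tokens, key_tokens] or [batch * heads, query_tokens, key_tokens].
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    bias = (1 - mask.to(dtype)) * -10000.0  # keep -> 0.0, discard -> -10000.0
    return bias.unsqueeze(1)                # add singleton query_tokens dim

mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
bias = mask_to_bias(mask)
print(bias.shape)  # torch.Size([2, 1, 3])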
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F

@@ -120,13 +120,13 @@ class BasicTransformerBlock(nn.Module):

def forward(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
cross_attention_kwargs=None,
class_labels=None,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
):
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention

@@ -155,8 +155,6 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
# TODO (Birch-San): Here we should prepare the encoder_attention mask correctly
# prepare attention mask here

attn_output = self.attn2(
norm_hidden_states,

@@ -380,7 +380,13 @@ class Attention(nn.Module):
if attention_mask is None:
return attention_mask

if attention_mask.shape[-1] != target_length:
current_length: int = attention_mask.shape[-1]
if current_length > target_length:
# we *could* trim the mask with:
# attention_mask = attention_mask[:,:target_length]
# but this is weird enough that it's more likely to be a mistake than a shortcut
raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
elif current_length < target_length:
if attention_mask.device.type == "mps":
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
# Instead, we can manually construct the padding tensor.

@@ -388,6 +394,10 @@ class Attention(nn.Module):
padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat([attention_mask, padding], dim=2)
else:
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
# we want to instead pad by (0, remaining_length), where remaining_length is:
# remaining_length: int = target_length - current_length
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

if out_dim == 3:
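Illustrative sketch (not part of the diff): in prepare_attention_mask a too-short mask is right-padded with zeros, and since the mask is already in bias form at this point (0.0 = keep, -10000.0 = discard), the padded positions count as "keep". Note the hunk above intentionally pads by target_length (what UnCLIP relies on) rather than by the remaining length, as its TODO explains; the snippet below pads by the remaining length purely to show the semantics.

import torch
import torch.nn.functional as F

bias = torch.tensor([[[0.0, -10000.0]]])          # [batch=1, 1, current_length=2], bias form
target_length = 4
pad_amount = target_length - bias.shape[-1]        # remaining length, for illustration only
padded = F.pad(bias, (0, pad_amount), value=0.0)   # new positions are 0.0, i.e. "keep"
print(padded.shape)  # torch.Size([1, 1, 4])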
@@ -820,7 +830,13 @@ class XFormersAttnProcessor:
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
):
residual = hidden_states

input_ndim = hidden_states.ndim

@@ -829,11 +845,20 @@ class XFormersAttnProcessor:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
batch_size, key_tokens, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
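Illustrative sketch (not part of the diff): the shape manipulation described in the comments above, in isolation. xformers expects the bias to carry an explicit query_tokens dimension, so the singleton dimension is expanded rather than left for broadcasting.

import torch

batch_heads, key_tokens, query_tokens = 16, 77, 4096
bias = torch.zeros(batch_heads, 1, key_tokens)    # [batch*heads, 1, key_tokens]
bias = bias.expand(-1, query_tokens, -1)          # [batch*heads, query_tokens, key_tokens]
print(bias.shape)  # torch.Size([16, 4096, 77])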
@@ -352,7 +352,7 @@ class LabelEmbedding(nn.Module):
labels = torch.where(drop_ids, self.num_classes, labels)
return labels

def forward(self, labels, force_drop_ids=None):
def forward(self, labels: torch.LongTensor, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (self.training and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F

@@ -213,11 +213,13 @@ class Transformer2DModel(ModelMixin, ConfigMixin):

def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
class_labels=None,
cross_attention_kwargs=None,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""

@@ -228,11 +230,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
timestep ( `torch.LongTensor`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
conditioning.
encoder_attention_mask ( `torch.Tensor`, *optional* ).
Cross-attention mask, applied to encoder_hidden_states. Two formats supported:
Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0
= keep, -10000 = discard.
If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

@@ -241,6 +249,29 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape

@@ -264,7 +295,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
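Illustrative sketch (not part of the diff): the two encoder_attention_mask formats documented above, a 2D keep/discard mask and a ready-made 3D bias, reduce to the same additive bias. Stand-alone check, not a call into diffusers:

import torch

keep = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]])     # [batch, seq], 1 = keep

# format 1: 2D mask, converted internally to a bias
bias_from_mask = ((1 - keep.float()) * -10000.0).unsqueeze(1)  # [batch, 1, seq]

# format 2: caller supplies the bias directly (0 = keep, -10000 = discard)
bias_direct = torch.where(keep.bool(), torch.tensor(0.0), torch.tensor(-10000.0)).unsqueeze(1)

assert torch.allclose(bias_from_mask, bias_direct)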
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch

@@ -558,14 +558,22 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.resnets = nn.ModuleList(resnets)

def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)

@@ -659,16 +667,34 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
self.resnets = nn.ModuleList(resnets)

def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# when attention_mask is defined: we don't even check for encoder_attention_mask.
# this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
# TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
# then we can simplify this whole if/else block to:
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask

hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
attention_mask=mask,
**cross_attention_kwargs,
)

@@ -850,9 +876,14 @@ class CrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False

def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
# TODO(Patrick, William) - attention mask is not used
output_states = ()

for resnet, attn in zip(self.resnets, self.attentions):

@@ -867,33 +898,32 @@ class CrossAttnDownBlock2D(nn.Module):

return custom_forward

if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
cross_attention_kwargs,
use_reentrant=False,
)[0]
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
cross_attention_kwargs,
)[0]
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
None, # timestep
None, # class_labels
cross_attention_kwargs,
attention_mask,
encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
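Illustrative sketch (not part of the diff): the repeated torch-version branches collapse into a single version-gated kwargs dict that is splatted into every checkpoint call, as in the hunk above. Stand-in module and a simplified is_torch_version helper, named here for illustration only:

from typing import Any, Dict

import torch
from packaging import version

def is_torch_version(op: str, ver: str) -> bool:
    # simplified stand-in for diffusers' utility; only ">=" is needed here
    assert op == ">="
    return version.parse(torch.__version__.split("+")[0]) >= version.parse(ver)

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
out = torch.utils.checkpoint.checkpoint(layer, x, **ckpt_kwargs)
print(out.shape)  # torch.Size([2, 8])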
@@ -1501,11 +1531,28 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False

def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
output_states = ()
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# when attention_mask is defined: we don't even check for encoder_attention_mask.
# this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
# TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
# then we can simplify this whole if/else block to:
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask

for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:

@@ -1523,6 +1570,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
mask,
cross_attention_kwargs,
)[0]
else:

@@ -1531,7 +1579,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
attention_mask=mask,
**cross_attention_kwargs,
)
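Illustrative sketch (not part of the diff): the mask-selection rule the Simple blocks above apply before calling attn. attention_mask takes precedence (UnCLIP passes its cross-attn mask through that parameter); otherwise encoder_attention_mask is used when cross-attending. Stand-alone helper, named here for illustration only:

from typing import Optional

import torch

def select_mask(
    attention_mask: Optional[torch.Tensor],
    encoder_attention_mask: Optional[torch.Tensor],
    encoder_hidden_states: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    if attention_mask is None:
        # cross-attending? then the cross-attn mask applies; otherwise no mask
        return None if encoder_hidden_states is None else encoder_attention_mask
    # attention_mask wins, for UnCLIP compatibility
    return attention_mask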
@@ -1690,7 +1738,13 @@ class KCrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False

def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
output_states = ()

@@ -1706,29 +1760,23 @@ class KCrossAttnDownBlock2D(nn.Module):

return custom_forward

if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
attention_mask,
cross_attention_kwargs,
use_reentrant=False,
)
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
attention_mask,
cross_attention_kwargs,
)
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
cross_attention_kwargs,
encoder_attention_mask,
**ckpt_kwargs,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(

@@ -1737,6 +1785,7 @@ class KCrossAttnDownBlock2D(nn.Module):
emb=temb,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)

if self.downsamplers is None:

@@ -1916,15 +1965,15 @@ class CrossAttnUpBlock2D(nn.Module):

def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
cross_attention_kwargs=None,
upsample_size=None,
attention_mask=None,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
# TODO(Patrick, William) - attention mask is not used
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]

@@ -1942,33 +1991,32 @@ class CrossAttnUpBlock2D(nn.Module):

return custom_forward

if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
cross_attention_kwargs,
use_reentrant=False,
)[0]
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
cross_attention_kwargs,
)[0]
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
None, # timestep
None, # class_labels
cross_attention_kwargs,
attention_mask,
encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]

@@ -2594,15 +2642,28 @@ class SimpleCrossAttnUpBlock2D(nn.Module):

def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
attention_mask=None,
cross_attention_kwargs=None,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# when attention_mask is defined: we don't even check for encoder_attention_mask.
# this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
# TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
# then we can simplify this whole if/else block to:
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask

for resnet, attn in zip(self.resnets, self.attentions):
# resnet
# pop res hidden states

@@ -2626,6 +2687,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
mask,
cross_attention_kwargs,
)[0]
else:

@@ -2634,7 +2696,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
attention_mask=mask,
**cross_attention_kwargs,
)

@@ -2811,13 +2873,14 @@ class KCrossAttnUpBlock2D(nn.Module):

def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
cross_attention_kwargs=None,
upsample_size=None,
attention_mask=None,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None:

@@ -2835,29 +2898,23 @@ class KCrossAttnUpBlock2D(nn.Module):

return custom_forward

if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
attention_mask,
cross_attention_kwargs,
use_reentrant=False,
)[0]
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
attention_mask,
cross_attention_kwargs,
)[0]
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
cross_attention_kwargs,
encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(

@@ -2866,6 +2923,7 @@ class KCrossAttnUpBlock2D(nn.Module):
emb=temb,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)

if self.upsamplers is not None:

@@ -2944,11 +3002,14 @@ class KAttentionBlock(nn.Module):

def forward(
self,
hidden_states,
encoder_hidden_states=None,
emb=None,
attention_mask=None,
cross_attention_kwargs=None,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
# TODO: mark emb as non-optional (self.norm2 requires it).
# requires assessing impact of change to positional param interface.
emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

@@ -2962,6 +3023,7 @@ class KAttentionBlock(nn.Module):
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
attn_output = self._to_4d(attn_output, height, weight)

@@ -2976,6 +3038,7 @@ class KAttentionBlock(nn.Module):
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
**cross_attention_kwargs,
)
attn_output = self._to_4d(attn_output, height, weight)
@@ -618,6 +618,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""

@@ -625,6 +626,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
encoder_attention_mask (`torch.Tensor`):
(batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
discard. Mask will be converted into a bias, which adds large negative values to attention scores
corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):

@@ -651,11 +656,27 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True

# prepare attention_mask
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0

@@ -727,6 +748,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

@@ -752,6 +774,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)

if mid_block_additional_residual is not None:

@@ -778,6 +801,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = upsample_block(
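Hedged usage sketch (not part of the diff): passing encoder_attention_mask to UNet2DConditionModel.forward. The tiny config below is an assumption chosen only so the snippet is self-contained (similar in spirit to the repository's test configs); in a real pipeline the mask would typically be the tokenizer's attention mask for the prompt (1 = real token, 0 = padding).

import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
)
sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 32)
encoder_attention_mask = torch.ones(1, 77, dtype=torch.bool)
encoder_attention_mask[:, 40:] = False  # discard padded prompt tokens

with torch.no_grad():
    out = unet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
    ).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])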
@@ -721,6 +721,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""

@@ -728,6 +729,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
encoder_attention_mask (`torch.Tensor`):
(batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False =
discard. Mask will be converted into a bias, which adds large negative values to attention scores
corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):

@@ -754,11 +759,27 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True

# prepare attention_mask
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)

# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0

@@ -830,6 +851,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

@@ -855,6 +877,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)

if mid_block_additional_residual is not None:

@@ -881,6 +904,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = upsample_block(

@@ -1188,9 +1212,14 @@ class CrossAttnDownBlockFlat(nn.Module):
self.gradient_checkpointing = False

def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
# TODO(Patrick, William) - attention mask is not used
output_states = ()

for resnet, attn in zip(self.resnets, self.attentions):

@@ -1205,33 +1234,32 @@ class CrossAttnDownBlockFlat(nn.Module):

return custom_forward

if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
cross_attention_kwargs,
use_reentrant=False,
)[0]
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
cross_attention_kwargs,
)[0]
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
None, # timestep
None, # class_labels
cross_attention_kwargs,
attention_mask,
encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]

@@ -1414,15 +1442,15 @@ class CrossAttnUpBlockFlat(nn.Module):

def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
cross_attention_kwargs=None,
upsample_size=None,
attention_mask=None,
hidden_states: torch.FloatTensor,
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
upsample_size: Optional[int] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
# TODO(Patrick, William) - attention mask is not used
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]

@@ -1440,33 +1468,32 @@ class CrossAttnUpBlockFlat(nn.Module):

return custom_forward

if is_torch_version(">=", "1.11.0"):
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
cross_attention_kwargs,
use_reentrant=False,
)[0]
else:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
cross_attention_kwargs,
)[0]
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
temb,
**ckpt_kwargs,
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
None, # timestep
None, # class_labels
cross_attention_kwargs,
attention_mask,
encoder_attention_mask,
**ckpt_kwargs,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]

@@ -1564,14 +1591,22 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
self.resnets = nn.ModuleList(resnets)

def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
):
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)

@@ -1666,16 +1701,34 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
self.resnets = nn.ModuleList(resnets)

def forward(
self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}

if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
mask = None if encoder_hidden_states is None else encoder_attention_mask
else:
# when attention_mask is defined: we don't even check for encoder_attention_mask.
# this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
# TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
# then we can simplify this whole if/else block to:
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask

hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
attention_mask=mask,
**cross_attention_kwargs,
)
@@ -20,6 +20,7 @@ import unittest

import torch
from parameterized import parameterized
from pytest import mark

from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, LoRAAttnProcessor

@@ -418,6 +419,76 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
assert processor.is_run
assert processor.number == 123

@parameterized.expand(
[
# fmt: off
[torch.bool],
[torch.long],
[torch.float],
# fmt: on
]
)
def test_model_xattn_mask(self, mask_dtype):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
model.to(torch_device)
model.eval()

cond = inputs_dict["encoder_hidden_states"]
with torch.no_grad():
full_cond_out = model(**inputs_dict).sample
assert full_cond_out is not None

keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype)
full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample
assert full_cond_keepallmask_out.allclose(
full_cond_out
), "a 'keep all' mask should give the same result as no mask"

trunc_cond = cond[:, :-1, :]
trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample
assert not trunc_cond_out.allclose(
full_cond_out
), "discarding the last token from our cond should change the result"

batch, tokens, _ = cond.shape
mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype)
masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample
assert masked_cond_out.allclose(
trunc_cond_out
), "masking the last token from our cond should be equivalent to truncating that token out of the condition"

# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric.
# maybe it's fine that this only works for the unclip use-case.
@mark.skip(
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
)
def test_model_xattn_padding(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
model.to(torch_device)
model.eval()

cond = inputs_dict["encoder_hidden_states"]
with torch.no_grad():
full_cond_out = model(**inputs_dict).sample
assert full_cond_out is not None

batch, tokens, _ = cond.shape
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"

trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
assert trunc_mask_out.allclose(
keeplast_out
), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."

def test_lora_processors(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()