Mirror of https://github.com/huggingface/diffusers.git
Multi-image masking for single IP Adapter (#7499)
* Support multi-image masking

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
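For context, here is a minimal usage sketch of what this commit enables: a single IP-Adapter driven by several reference images, each confined to its own region via a mask. The checkpoint, weight file, image paths, and sizes below are illustrative placeholders (following the diffusers IP-Adapter docs), not something this diff prescribes; the key points are the [1, num_images_for_ip_adapter, height, width] mask layout and the nested ip_adapter_image list.

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# One IP-Adapter that will be fed two reference images.
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.safetensors")

face1, face2 = load_image("face1.png"), load_image("face2.png")  # placeholder images
mask1, mask2 = load_image("mask1.png"), load_image("mask2.png")  # placeholder binary masks

processor = IPAdapterMaskProcessor()
masks = processor.preprocess([mask1, mask2], height=1024, width=1024)  # [2, 1, 1024, 1024]
# One list entry per adapter, shaped [1, num_images_for_ip_adapter, height, width]:
masks = [masks.reshape(1, masks.shape[0], masks.shape[2], masks.shape[3])]

image = pipeline(
    prompt="two people, photorealistic",
    ip_adapter_image=[[face1, face2]],  # nested list: all images for adapter 0
    cross_attention_kwargs={"ip_adapter_masks": masks},
).images[0]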
@@ -13,7 +13,7 @@
 # limitations under the License.
 import inspect
 from importlib import import_module
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -2195,15 +2195,33 @@ class IPAdapterAttnProcessor(nn.Module):
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
         if ip_adapter_masks is not None:
-            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
-                raise ValueError(
-                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
-                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                )
-            if len(ip_adapter_masks) != len(self.scale):
-                raise ValueError(
-                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
-                )
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                raise ValueError(
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
+                )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
         else:
             ip_adapter_masks = [None] * len(self.scale)
 
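The validation above fixes the new contract: ip_adapter_masks is a Python list with one 4-D tensor per adapter, whose second dimension enumerates that adapter's images, while the old single-tensor form is silently upgraded. A toy sketch of both paths, with made-up sizes:

import torch

# New-style input: one tensor per IP-Adapter, shape [1, num_images_for_ip_adapter, H, W].
masks_new = [torch.ones(1, 2, 64, 64)]  # one adapter, two masked images

# Old-style input: a single tensor of shape [num_ip_adapter, 1, H, W].
old = torch.ones(3, 1, 64, 64)
# The backward-compatibility branch converts it to the new layout:
compat = list(old.unsqueeze(1))  # three tensors, each [1, 1, 64, 64]
assert all(m.shape == (1, 1, 64, 64) for m in compat)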
@@ -2211,26 +2229,44 @@ class IPAdapterAttnProcessor(nn.Module):
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = attn.head_to_batch_dim(ip_key)
-            ip_value = attn.head_to_batch_dim(ip_value)
-
-            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
-            current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
-            current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
-
             if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
-
-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-
-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
-
-            hidden_states = hidden_states + scale * current_ip_hidden_states
+                if not isinstance(scale, list):
+                    scale = [scale]
+
+                current_num_images = mask.shape[1]
+                for i in range(current_num_images):
+                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                    ip_key = attn.head_to_batch_dim(ip_key)
+                    ip_value = attn.head_to_batch_dim(ip_value)
+
+                    ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                    _current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                    _current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
+
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask[:, i, :, :],
+                        batch_size,
+                        _current_ip_hidden_states.shape[1],
+                        _current_ip_hidden_states.shape[2],
+                    )
+
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+
+                    hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+            else:
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = attn.head_to_batch_dim(ip_key)
+                ip_value = attn.head_to_batch_dim(ip_value)
+
+                ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+                current_ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+                current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
+
+                hidden_states = hidden_states + scale * current_ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
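In the masked branch, each per-image slice mask[:, i, :, :] is shrunk by IPAdapterMaskProcessor.downsample to the token resolution of the current attention layer before being multiplied in. The following is a rough standalone sketch of that idea, not the library's exact implementation (the real method also pads when the downsampled grid does not hit num_queries exactly and warns on aspect-ratio mismatch):

import math
import torch
import torch.nn.functional as F

def downsample_sketch(mask: torch.Tensor, num_queries: int, embed_dim: int) -> torch.Tensor:
    # mask: [batch, H, W] -> [batch, num_queries, embed_dim]
    b, h, w = mask.shape
    ratio = math.sqrt(num_queries / (h * w))
    o_h, o_w = int(h * ratio), int(w * ratio)  # target grid with o_h * o_w ~= num_queries
    mask = F.interpolate(mask.unsqueeze(1), size=(o_h, o_w), mode="bicubic").squeeze(1)
    mask = mask.view(b, -1, 1)           # one mask value per attention token
    return mask.repeat(1, 1, embed_dim)  # broadcast across feature channels

tokens = downsample_sketch(torch.ones(1, 64, 64), num_queries=1024, embed_dim=8)
print(tokens.shape)  # torch.Size([1, 1024, 8])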
@@ -2369,15 +2405,33 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
         hidden_states = hidden_states.to(query.dtype)
 
         if ip_adapter_masks is not None:
-            if not isinstance(ip_adapter_masks, torch.Tensor) or ip_adapter_masks.ndim != 4:
-                raise ValueError(
-                    " ip_adapter_mask should be a tensor with shape [num_ip_adapter, 1, height, width]."
-                    " Please use `IPAdapterMaskProcessor` to preprocess your mask"
-                )
-            if len(ip_adapter_masks) != len(self.scale):
-                raise ValueError(
-                    f"Number of ip_adapter_masks ({len(ip_adapter_masks)}) must match number of IP-Adapters ({len(self.scale)})"
-                )
+            if not isinstance(ip_adapter_masks, List):
+                # for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
+                ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+            if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
+                raise ValueError(
+                    f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
+                    f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
+                    f"({len(ip_hidden_states)})"
+                )
+            else:
+                for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                        raise ValueError(
+                            "Each element of the ip_adapter_masks array should be a tensor with shape "
+                            "[1, num_images_for_ip_adapter, height, width]."
+                            " Please use `IPAdapterMaskProcessor` to preprocess your mask"
+                        )
+                    if mask.shape[1] != ip_state.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of ip images ({ip_state.shape[1]}) at index {index}"
+                        )
+                    if isinstance(scale, list) and not len(scale) == mask.shape[1]:
+                        raise ValueError(
+                            f"Number of masks ({mask.shape[1]}) does not match "
+                            f"number of scales ({len(scale)}) at index {index}"
+                        )
         else:
             ip_adapter_masks = [None] * len(self.scale)
 
@@ -2385,33 +2439,57 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
         for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
             ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
         ):
-            ip_key = to_k_ip(current_ip_hidden_states)
-            ip_value = to_v_ip(current_ip_hidden_states)
-
-            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            current_ip_hidden_states = F.scaled_dot_product_attention(
-                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-
-            current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                batch_size, -1, attn.heads * head_dim
-            )
-            current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
-
             if mask is not None:
-                mask_downsample = IPAdapterMaskProcessor.downsample(
-                    mask, batch_size, current_ip_hidden_states.shape[1], current_ip_hidden_states.shape[2]
-                )
-
-                mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
-
-                current_ip_hidden_states = current_ip_hidden_states * mask_downsample
-
-            hidden_states = hidden_states + scale * current_ip_hidden_states
+                if not isinstance(scale, list):
+                    scale = [scale]
+
+                current_num_images = mask.shape[1]
+                for i in range(current_num_images):
+                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+
+                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    _current_ip_hidden_states = F.scaled_dot_product_attention(
+                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                    )
+
+                    _current_ip_hidden_states = _current_ip_hidden_states.transpose(1, 2).reshape(
+                        batch_size, -1, attn.heads * head_dim
+                    )
+                    _current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
+
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask[:, i, :, :],
+                        batch_size,
+                        _current_ip_hidden_states.shape[1],
+                        _current_ip_hidden_states.shape[2],
+                    )
+
+                    mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
+                    hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
+            else:
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+
+                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = F.scaled_dot_product_attention(
+                    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+                hidden_states = hidden_states + scale * current_ip_hidden_states
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
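Both processors now run one attention pass per reference image and accumulate the masked, scaled outputs into hidden_states. A toy numeric sketch of that accumulation (all tensors invented; scale is the per-image list the validation above allows):

import torch

hidden_states = torch.zeros(1, 4, 2)  # [batch, tokens, dim]
attn_outputs = [torch.ones(1, 4, 2), 2 * torch.ones(1, 4, 2)]  # one per image
masks = [
    torch.tensor([1.0, 1.0, 0.0, 0.0]).view(1, 4, 1),  # image 0 owns tokens 0-1
    torch.tensor([0.0, 0.0, 1.0, 1.0]).view(1, 4, 1),  # image 1 owns tokens 2-3
]
scale = [1.0, 0.5]  # per-image weights

# Mirrors: hidden_states += scale[i] * (_current_ip_hidden_states * mask_downsample)
for i in range(2):
    hidden_states = hidden_states + scale[i] * (attn_outputs[i] * masks[i])

print(hidden_states[0, :, 0])  # tensor([1., 1., 1., 1.]) -- disjoint regions blend cleanly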