Mirror of https://github.com/huggingface/diffusers.git

handle ip adapter params correctly
@@ -40,8 +40,6 @@ if is_transformers_available():
     from ..models.attention_processor import (
         AttnProcessor,
         AttnProcessor2_0,
-        FluxAttnProcessor2_0,
-        FluxIPAdapterJointAttnProcessor2_0,
         IPAdapterAttnProcessor,
         IPAdapterAttnProcessor2_0,
         IPAdapterXFormersAttnProcessor,

@@ -867,6 +865,9 @@ class FluxIPAdapterMixin:
         >>> ...
         ```
         """
+        # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+        from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
         # remove CLIP image encoder
         if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
             self.image_encoder = None

@@ -886,9 +887,9 @@ class FluxIPAdapterMixin:
         # restore original Transformer attention processors layers
         attn_procs = {}
         for name, value in self.transformer.attn_processors.items():
-            attn_processor_class = FluxAttnProcessor2_0()
+            attn_processor_class = FluxAttnProcessor()
             attn_procs[name] = (
-                attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+                attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
             )
         self.transformer.set_attn_processor(attn_procs)

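The two hunks above update `FluxIPAdapterMixin.unload_ip_adapter` to import the Flux processors from `transformer_flux` and to detect loaded IP-adapter layers via `FluxIPAdapterAttnProcessor`, restoring plain `FluxAttnProcessor` instances afterwards. A minimal usage sketch of the mixin methods involved (the repo ids and weight name below are illustrative examples, not taken from this commit):

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Load an IP adapter; checkpoint and image encoder ids are placeholders.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)

# ... run inference with ip_adapter_image=... ...

# After unloading, every attention processor should be a plain FluxAttnProcessor again.
pipe.unload_ip_adapter()
print({type(p).__name__ for p in pipe.transformer.attn_processors.values()})
```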
@@ -87,9 +87,7 @@ class FluxTransformer2DLoadersMixin:
         return image_projection

     def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-        from ..models.attention_processor import (
-            FluxIPAdapterJointAttnProcessor2_0,
-        )
+        from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

         if low_cpu_mem_usage:
             if is_accelerate_available():

@@ -121,7 +119,7 @@ class FluxTransformer2DLoadersMixin:
             else:
                 cross_attention_dim = self.config.joint_attention_dim
                 hidden_size = self.inner_dim
-                attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+                attn_processor_class = FluxIPAdapterAttnProcessor
                 num_image_text_embeds = []
                 for state_dict in state_dicts:
                     if "proj.weight" in state_dict["image_proj"]:

@@ -2501,152 +2501,6 @@ class FusedFluxAttnProcessor2_0_NPU:
         return hidden_states


-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_query = hidden_states_query_proj
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on

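The hunk above deletes the legacy `FluxIPAdapterJointAttnProcessor2_0` from `attention_processor.py`; the equivalent logic now lives in `FluxIPAdapterAttnProcessor` in `transformer_flux.py`. As a reading aid, here is a stripped-down sketch of the per-adapter key/value projection and scaled accumulation the processor performs on top of the base attention output. Names and shapes are illustrative; this is not the diffusers implementation itself.

```python
import torch
import torch.nn.functional as F


def ip_adapter_attention(ip_query, ip_hidden_states, to_k_ip, to_v_ip, scales, heads):
    # ip_query: (batch, heads, seq_len, head_dim) -- the same query used for the text/image tokens
    batch_size, head_dim = ip_query.shape[0], ip_query.shape[-1]
    out = None
    for ip_states, scale, to_k, to_v in zip(ip_hidden_states, scales, to_k_ip, to_v_ip):
        # project the image-prompt embeddings into per-head keys/values
        ip_key = to_k(ip_states).view(batch_size, -1, heads, head_dim).transpose(1, 2)
        ip_value = to_v(ip_states).view(batch_size, -1, heads, head_dim).transpose(1, 2)
        attn = F.scaled_dot_product_attention(ip_query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
        attn = attn.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)
        # each adapter contributes scale * attention(ip tokens) to the output
        out = scale * attn if out is None else out + scale * attn
    return out


if __name__ == "__main__":
    b, heads, head_dim, seq, tokens, cross_dim = 1, 4, 32, 16, 4, 64
    hidden = heads * head_dim
    q = torch.randn(b, heads, seq, head_dim)
    ip_states = [torch.randn(b, tokens, cross_dim)]
    to_k = torch.nn.ModuleList([torch.nn.Linear(cross_dim, hidden)])
    to_v = torch.nn.ModuleList([torch.nn.Linear(cross_dim, hidden)])
    out = ip_adapter_attention(q, ip_states, to_k, to_v, scales=[1.0], heads=heads)
    print(out.shape)  # torch.Size([1, 16, 128])
```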
@@ -6019,6 +5873,16 @@ class FusedFluxAttnProcessor2_0:
         return FluxAttnProcessor(*args, **kwargs)


+class FluxIPAdapterJointAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+        deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+        return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,

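The hunk above keeps the old class name importable as a thin shim: `__new__` emits a deprecation warning through diffusers' `deprecate` helper and returns an instance of the new class. A generic sketch of the same pattern, using `warnings.warn` as a stand-in for `deprecate`:

```python
import warnings


class NewProcessor:
    def __init__(self, scale=1.0):
        self.scale = scale


class OldProcessor:  # deprecated alias
    def __new__(cls, *args, **kwargs):
        warnings.warn(
            "`OldProcessor` is deprecated; use `NewProcessor` instead.",
            FutureWarning,
            stacklevel=2,
        )
        # Returning a non-instance of `cls` skips OldProcessor.__init__ entirely,
        # so callers transparently get the new class.
        return NewProcessor(*args, **kwargs)


proc = OldProcessor(scale=0.5)
assert isinstance(proc, NewProcessor)
```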
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import inspect
 from typing import Any, Dict, List, Optional, Tuple, Union

 import numpy as np

@@ -241,7 +241,9 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = torch.nn.functional(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

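The removed line called the `torch.nn.functional` module object itself, which raises `TypeError: 'module' object is not callable` at runtime; the replacement calls `scaled_dot_product_attention` explicitly. A minimal sanity check of the corrected call (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 1, 24, 512, 128
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# scaled_dot_product_attention expects (batch, heads, seq_len, head_dim) tensors
out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 24, 512, 128])
```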
@@ -354,6 +356,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
         image_rotary_emb: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
         return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)

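The hunk above makes `FluxAttention.forward` filter `joint_attention_kwargs` against the active processor's `__call__` signature: genuinely unexpected keys trigger a warning, while the IP-adapter keys are dropped silently when a non-IP-adapter processor is in use. A standalone sketch of that filtering (names are illustrative, not the diffusers code):

```python
import inspect
import logging

logger = logging.getLogger(__name__)


def filter_processor_kwargs(processor, kwargs, quiet=("ip_adapter_masks", "ip_hidden_states")):
    # keyword arguments the processor's __call__ actually declares
    accepted = set(inspect.signature(processor.__call__).parameters.keys())
    unexpected = [k for k in kwargs if k not in accepted and k not in quiet]
    if unexpected:
        logger.warning(f"{unexpected} are not expected by {type(processor).__name__} and will be ignored.")
    # forward only what the processor can accept
    return {k: v for k, v in kwargs.items() if k in accepted}


class DummyProcessor:
    def __call__(self, attn, hidden_states, ip_hidden_states=None):
        return hidden_states


proc = DummyProcessor()
print(filter_processor_kwargs(proc, {"ip_hidden_states": [1], "temb": None}))
# -> {'ip_hidden_states': [1]}; 'temb' triggers a warning and is dropped.
```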