From b8f7fe61e1c5136cb8f88ee7ebe14c7b7c95fb13 Mon Sep 17 00:00:00 2001
From: Aryan
Date: Mon, 14 Jul 2025 08:21:47 +0200
Subject: [PATCH] handle ip adapter params correctly

---
 src/diffusers/loaders/ip_adapter.py          |   9 +-
 src/diffusers/loaders/transformer_flux.py    |   6 +-
 src/diffusers/models/attention_processor.py  | 156 ++----------------
 .../models/transformers/transformer_flux.py  |  14 +-
 4 files changed, 29 insertions(+), 156 deletions(-)

diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index e05d53687a..dca4758ba0 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -40,8 +40,6 @@ if is_transformers_available():
     from ..models.attention_processor import (
         AttnProcessor,
         AttnProcessor2_0,
-        FluxAttnProcessor2_0,
-        FluxIPAdapterJointAttnProcessor2_0,
         IPAdapterAttnProcessor,
         IPAdapterAttnProcessor2_0,
         IPAdapterXFormersAttnProcessor,
@@ -867,6 +865,9 @@ class FluxIPAdapterMixin:
         >>> ...
         ```
         """
+        # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+        from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
         # remove CLIP image encoder
         if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
             self.image_encoder = None
@@ -886,9 +887,9 @@ class FluxIPAdapterMixin:
         # restore original Transformer attention processors layers
         attn_procs = {}
         for name, value in self.transformer.attn_processors.items():
-            attn_processor_class = FluxAttnProcessor2_0()
+            attn_processor_class = FluxAttnProcessor()
             attn_procs[name] = (
-                attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+                attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
             )
         self.transformer.set_attn_processor(attn_procs)

diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py
index af03d09029..0873e8edd0 100644
--- a/src/diffusers/loaders/transformer_flux.py
+++ b/src/diffusers/loaders/transformer_flux.py
@@ -87,9 +87,7 @@ class FluxTransformer2DLoadersMixin:
         return image_projection
 
     def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-        from ..models.attention_processor import (
-            FluxIPAdapterJointAttnProcessor2_0,
-        )
+        from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
 
         if low_cpu_mem_usage:
             if is_accelerate_available():
@@ -121,7 +119,7 @@ class FluxTransformer2DLoadersMixin:
             else:
                 cross_attention_dim = self.config.joint_attention_dim
             hidden_size = self.inner_dim
-            attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+            attn_processor_class = FluxIPAdapterAttnProcessor
             num_image_text_embeds = []
             for state_dict in state_dicts:
                 if "proj.weight" in state_dict["image_proj"]:

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 2306bdbc9d..e64bd45eb4 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2501,152 +2501,6 @@ class FusedFluxAttnProcessor2_0_NPU:
         return hidden_states
 
 
-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
-    """Flux Attention processor for IP-Adapter."""
-
-    def __init__(
-        self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
-    ):
-        super().__init__()
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-
-        if not isinstance(num_tokens, (tuple, list)):
-            num_tokens = [num_tokens]
-
-        if not isinstance(scale, list):
-            scale = [scale] * len(num_tokens)
-        if len(scale) != len(num_tokens):
-            raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
-        self.scale = scale
-
-        self.to_k_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-        self.to_v_ip = nn.ModuleList(
-            [
-                nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
-                for _ in range(len(num_tokens))
-            ]
-        )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-        ip_hidden_states: Optional[List[torch.Tensor]] = None,
-        ip_adapter_masks: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        hidden_states_query_proj = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            # IP-adapter
-            ip_query = hidden_states_query_proj
-            ip_attn_output = torch.zeros_like(hidden_states)
-
-            for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
-                ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
-            ):
-                ip_key = to_k_ip(current_ip_hidden_states)
-                ip_value = to_v_ip(current_ip_hidden_states)
-
-                ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                # the output of sdp = (batch, num_heads, seq_len, head_dim)
-                # TODO: add support for attn.scale when we move to Torch 2.1
-                current_ip_hidden_states = F.scaled_dot_product_attention(
-                    ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                )
-                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
-                    batch_size, -1, attn.heads * head_dim
-                )
-                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
-                ip_attn_output += scale * current_ip_hidden_states
-
-            return hidden_states, encoder_hidden_states, ip_attn_output
-        else:
-            return hidden_states
-
-
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -6019,6 +5873,16 @@ class FusedFluxAttnProcessor2_0:
         return FluxAttnProcessor(*args, **kwargs)
 
 
+class FluxIPAdapterJointAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+        deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+        return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 8218b2ae6e..7898bdb1f0 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import inspect
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -241,7 +241,9 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = torch.nn.functional(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -354,6 +356,14 @@ class FluxAttention(torch.nn.Module, AttentionModuleMixin):
         image_rotary_emb: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
+        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+        quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+        unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+        if len(unused_kwargs) > 0:
+            logger.warning(
+                f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+            )
+        kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
         return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
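
After this patch, `FluxIPAdapterJointAttnProcessor2_0` survives in `attention_processor.py` only as a `__new__`-based shim: instantiating it emits a deprecation warning and hands back the new class. A minimal sketch of the expected behavior, assuming `FluxIPAdapterAttnProcessor` keeps the old constructor signature (the sizes below are the usual Flux dimensions, not values taken from this patch):

```python
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

# __new__ forwards *args/**kwargs to the new class, so the object we get back is
# not an instance of the deprecated class at all; a deprecation warning is emitted.
proc = FluxIPAdapterJointAttnProcessor2_0(hidden_size=3072, cross_attention_dim=4096, num_tokens=(4,))
assert isinstance(proc, FluxIPAdapterAttnProcessor)
```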
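The `FluxAttention.forward` change is the part that actually handles IP-adapter params correctly: instead of blindly forwarding `joint_attention_kwargs` (which breaks non-IP processors on unexpected keys), it filters them against the active processor's signature and warns about leftovers, staying silent for the two IP-adapter keys that pipelines pass unconditionally. The pattern in isolation, with a hypothetical toy processor standing in for the real ones:

```python
import inspect

# Hypothetical stand-in for a Flux attention processor; only its signature matters here.
class ToyProcessor:
    def __call__(self, attn, hidden_states, ip_hidden_states=None):
        return hidden_states

processor = ToyProcessor()
kwargs = {"ip_hidden_states": ["<embeds>"], "ip_adapter_masks": None, "typo_arg": 42}

# Keys the processor's __call__ actually accepts (bound method, so `self` is excluded).
attn_parameters = set(inspect.signature(processor.__call__).parameters.keys())
# Pipelines pass these even to non-IP processors, so they are dropped without a warning.
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}

unused = [k for k in kwargs if k not in attn_parameters and k not in quiet_attn_parameters]
print(unused)  # ['typo_arg'] -> this is what triggers logger.warning in FluxAttention.forward
kwargs = {k: v for k, v in kwargs.items() if k in attn_parameters}
print(kwargs)  # {'ip_hidden_states': ['<embeds>']} -- 'ip_adapter_masks' filtered out silently
```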
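From the user side, nothing changes: loading and unloading a Flux IP-adapter goes through the same mixin methods, and `unload_ip_adapter` now restores a plain `FluxAttnProcessor` wherever a `FluxIPAdapterAttnProcessor` was installed. A usage sketch; the checkpoint names are illustrative, not part of this patch:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",  # example repo; any Flux IP-adapter checkpoint works the same way
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
# ... run IP-adapter-guided inference ...
pipe.unload_ip_adapter()  # drops the image encoder and swaps back to FluxAttnProcessor via the loop patched above
```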