From 251bb619250900c60273621a071bef22b1971dca Mon Sep 17 00:00:00 2001
From: Aryan
Date: Wed, 16 Jul 2025 12:50:54 +0200
Subject: [PATCH] minify and deprecate npu/xla processors

---
 src/diffusers/models/attention_processor.py | 396 +++-----------------
 1 file changed, 61 insertions(+), 335 deletions(-)

diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index e64bd45eb4..990245de17 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2272,235 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0:
         return hidden_states
 
 
-class FluxAttnProcessor2_0_NPU:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        deprecation_message = (
-            "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
-            "alternative solution to use NPU Flash Attention will be provided in the future."
-        )
-        deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        if query.dtype in (torch.float16, torch.bfloat16):
-            hidden_states = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                attn.heads,
-                input_layout="BNSD",
-                pse=None,
-                scale=1.0 / math.sqrt(query.shape[-1]),
-                pre_tockens=65536,
-                next_tockens=65536,
-                keep_prob=1.0,
-                sync=False,
-                inner_precise=0,
-            )[0]
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
-class FusedFluxAttnProcessor2_0_NPU:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        deprecation_message = (
-            "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
-            "alternative solution to use NPU Flash Attention will be provided in the future."
-        )
-        deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            (
-                encoder_hidden_states_query_proj,
-                encoder_hidden_states_key_proj,
-                encoder_hidden_states_value_proj,
-            ) = torch.split(encoder_qkv, split_size, dim=-1)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        if query.dtype in (torch.float16, torch.bfloat16):
-            hidden_states = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                attn.heads,
-                input_layout="BNSD",
-                pse=None,
-                scale=1.0 / math.sqrt(query.shape[-1]),
-                pre_tockens=65536,
-                next_tockens=65536,
-                keep_prob=1.0,
-                sync=False,
-                inner_precise=0,
-            )[0]
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3130,112 +2901,6 @@ class XLAFlashAttnProcessor2_0:
         return hidden_states
 
 
-class XLAFluxFlashAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
- """ - - def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None): - deprecation_message = ( - "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An " - "alternative solution to using XLA Flash Attention will be provided in the future." - ) - deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError( - "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." - ) - if is_torch_xla_version("<", "2.3"): - raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") - if is_spmd() and is_torch_xla_version("<", "2.4"): - raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") - self.partition_spec = partition_spec - - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. 
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        query /= math.sqrt(head_dim)
-        hidden_states = flash_attention(query, key, value, causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
-
-
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
@@ -5883,6 +5548,67 @@ class FluxIPAdapterJointAttnProcessor2_0:
         return FluxIPAdapterAttnProcessor(*args, **kwargs)
 
 
+class FluxAttnProcessor2_0_NPU:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_native_npu"
+        return processor
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_fused_npu"
+        return processor
+
+
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+ """ + + def __new__(cls, *args, **kwargs): + deprecation_message = ( + "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An " + "alternative solution to using XLA Flash Attention will be provided in the future." + ) + deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False) + + if is_torch_xla_version("<", "2.3"): + raise ImportError("XLA flash attention requires torch_xla version >= 2.3.") + if is_spmd() and is_torch_xla_version("<", "2.4"): + raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.") + + from .transformers.transformer_flux import FluxAttnProcessor + + if len(args) > 0 or kwargs.get("partition_spec", None) is not None: + deprecation_message = ( + "partition_spec was not used in the processor implementation when it was added. Passing it " + "is a no-op and support for it will be removed." + ) + deprecate("partition_spec", "1.0.0", deprecation_message) + + processor = FluxAttnProcessor(*args, **kwargs) + processor._attention_backend = "_native_xla" + return processor + + ADDED_KV_ATTENTION_PROCESSORS = ( AttnAddedKVProcessor, SlicedAttnAddedKVProcessor,