1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-04-14 20:12:14 +05:30
parent b67fcf2221
commit a923a73a17

View File

@@ -47,10 +47,19 @@ else:
class AttentionModuleMixin:
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
r"""
Set whether to use npu flash attention from `torch_npu` or not.
"""
A mixin class that provides common methods for attention modules.
This mixin adds functionality to set different attention processors, handle attention masks,
compute attention scores, and manage projections.
"""
def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
"""
Set whether to use NPU flash attention from `torch_npu` or not.
Args:
use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
"""
if use_npu_flash_attention:
processor = AttnProcessorNPU()
@@ -72,14 +81,16 @@ class AttentionModuleMixin:
partition_spec: Optional[Tuple[Optional[str], ...]] = None,
is_flux=False,
) -> None:
r"""
Set whether to use xla flash attention from `torch_xla` or not.
"""
Set whether to use XLA flash attention from `torch_xla` or not.
Args:
use_xla_flash_attention (`bool`):
Whether to use pallas flash attention kernel from `torch_xla` or not.
partition_spec (`Tuple[]`, *optional*):
Specify the partition specification if using SPMD. Otherwise None.
is_flux (`bool`, *optional*, defaults to `False`):
Whether the model is a Flux model.
"""
if use_xla_flash_attention:
if not is_torch_xla_available:
@@ -104,7 +115,7 @@ class AttentionModuleMixin:
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
) -> None:
r"""
"""
Set whether to use memory efficient attention from `xformers` or not.
Args:
@@ -248,7 +259,7 @@ class AttentionModuleMixin:
self.set_processor(processor)
def set_attention_slice(self, slice_size: int) -> None:
r"""
"""
Set the slice size for attention computation.
Args:
@@ -278,7 +289,7 @@ class AttentionModuleMixin:
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor") -> None:
r"""
"""
Set the attention processor to use.
Args:
@@ -298,7 +309,7 @@ class AttentionModuleMixin:
self.processor = processor
def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
r"""
"""
Get the attention processor in use.
Args:
@@ -312,9 +323,8 @@ class AttentionModuleMixin:
return self.processor
def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
r"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
is the number of heads initialized while constructing the `Attention` class.
"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
@@ -329,14 +339,12 @@ class AttentionModuleMixin:
return tensor
def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
r"""
Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
the number of heads initialized while constructing the `Attention` class.
"""
Reshape the tensor for multi-head attention processing.
Args:
tensor (`torch.Tensor`): The tensor to reshape.
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
reshaped to `[batch_size * heads, seq_len, dim // heads]`.
out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
Returns:
`torch.Tensor`: The reshaped tensor.
@@ -358,13 +366,13 @@ class AttentionModuleMixin:
def get_attention_scores(
self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
r"""
"""
Compute the attention scores.
Args:
query (`torch.Tensor`): The query tensor.
key (`torch.Tensor`): The key tensor.
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
Returns:
`torch.Tensor`: The attention probabilities/scores.
@@ -405,18 +413,14 @@ class AttentionModuleMixin:
def prepare_attention_mask(
self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
) -> torch.Tensor:
r"""
"""
Prepare the attention mask for the attention computation.
Args:
attention_mask (`torch.Tensor`):
The attention mask to prepare.
target_length (`int`):
The target length of the attention mask. This is the length of the attention mask after padding.
batch_size (`int`):
The batch size, which is used to repeat the attention mask.
out_dim (`int`, *optional*, defaults to `3`):
The output dimension of the attention mask. Can be either `3` or `4`.
attention_mask (`torch.Tensor`): The attention mask to prepare.
target_length (`int`): The target length of the attention mask.
batch_size (`int`): The batch size for repeating the attention mask.
out_dim (`int`, *optional*, defaults to `3`): Output dimension.
Returns:
`torch.Tensor`: The prepared attention mask.
@@ -450,9 +454,8 @@ class AttentionModuleMixin:
return attention_mask
def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
r"""
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
`Attention` class.
"""
Normalize the encoder hidden states.
Args:
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
@@ -461,7 +464,6 @@ class AttentionModuleMixin:
`torch.Tensor`: The normalized encoder hidden states.
"""
assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
if isinstance(self.norm_cross, nn.LayerNorm):
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
elif isinstance(self.norm_cross, nn.GroupNorm):
@@ -480,6 +482,12 @@ class AttentionModuleMixin:
@torch.no_grad()
def fuse_projections(self, fuse=True):
"""
Fuse the query, key, and value projections into a single projection for efficiency.
Args:
fuse (`bool`): Whether to fuse the projections or not.
"""
device = self.to_q.weight.data.device
dtype = self.to_q.weight.data.dtype
@@ -4534,7 +4542,7 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
return hidden_states
class CustomDiffusionAttnProcessor2_0(nn.Module):
class CustomDiffusionAttnProcessorSDPA(nn.Module):
r"""
Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0's memory-efficient scaled
dot-product attention.
@@ -5056,13 +5064,6 @@ class IPAdapterAttnProcessor(nn.Module):
return hidden_states
class IPAdapterAttnProcessor2_0(IPAdapterAttnProcessorSDPA):
def __init__(self, *args, **kwargs) -> None:
deprecation_message = "`IPAdapterAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `IPAdapterAttnProcessorSDPA`"
deprecate("IPAdapterAttnProcessor2_0", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)
class IPAdapterAttnProcessorSDPA(torch.nn.Module):
r"""
Attention processor for IP-Adapter for PyTorch 2.0.
@@ -5527,7 +5528,7 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
return hidden_states
class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
class SD3IPAdapterJointAttnProcessorSDPA(torch.nn.Module):
"""
Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
additional image-based information and timestep embeddings.
@@ -5996,13 +5997,13 @@ class LoRAAttnAddedKVProcessor:
pass
class FluxSingleAttnProcessor2_0(FluxSingleAttnProcessorSDPA):
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    Deprecated: this class only exists so that code written against the legacy `FluxSingleAttnProcessor2_0` name
    keeps working. Constructing it reports the deprecation through `deprecate` and then runs the parent constructor.
    """

    def __init__(self):
        # The deprecated name is the legacy `2_0` class itself, not an `SDPA` class.
        # NOTE(review): the base class is `FluxSingleAttnProcessorSDPA` while the message recommends
        # `FluxAttnProcessorSDPA` -- confirm which replacement is intended.
        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
        deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
        super().__init__()
@@ -6171,288 +6172,290 @@ class PAGIdentitySanaLinearAttnProcessorSDPA:
class MochiAttnProcessor2_0(MochiAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `MochiAttnProcessor2_0` name.

    Subclasses `MochiAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `MochiAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `MochiAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`MochiAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiAttnProcessorSDPA`"
        deprecate("MochiAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class AttnAddedKVProcessor2_0(AttnAddedKVProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `AttnAddedKVProcessor2_0` name.

    Subclasses `AttnAddedKVProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `AttnAddedKVProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `AttnAddedKVProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`AttnAddedKVProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
        deprecate("AttnAddedKVProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class JointAttnProcessor2_0(JointAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `JointAttnProcessor2_0` name.

    Subclasses `JointAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `JointAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `JointAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`JointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `JointAttnProcessorSDPA`"
        deprecate("JointAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class PAGJointAttnProcessor2_0(PAGJointAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `PAGJointAttnProcessor2_0` name.

    Subclasses `PAGJointAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `PAGJointAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `PAGJointAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`PAGJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGJointAttnProcessorSDPA`"
        deprecate("PAGJointAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class PAGCFGJointAttnProcessor2_0(PAGCFGJointAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `PAGCFGJointAttnProcessor2_0` name.

    Subclasses `PAGCFGJointAttnProcessorSDPA` without overriding anything except the constructor, which reports
    the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `PAGCFGJointAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `PAGCFGJointAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`PAGCFGJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGJointAttnProcessorSDPA`"
        deprecate("PAGCFGJointAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class FusedJointAttnProcessor2_0(FusedJointAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `FusedJointAttnProcessor2_0` name.

    Subclasses `FusedJointAttnProcessorSDPA` without overriding anything except the constructor, which reports
    the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `FusedJointAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `FusedJointAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`FusedJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedJointAttnProcessorSDPA`"
        deprecate("FusedJointAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class AllegroAttnProcessor2_0(AllegroAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `AllegroAttnProcessor2_0` name.

    Subclasses `AllegroAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `AllegroAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `AllegroAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`AllegroAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AllegroAttnProcessorSDPA`"
        deprecate("AllegroAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class AuraFlowAttnProcessor2_0(AuraFlowAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `AuraFlowAttnProcessor2_0` name.

    Subclasses `AuraFlowAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `AuraFlowAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `AuraFlowAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`AuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`"
        deprecate("AuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class FusedAuraFlowAttnProcessor2_0(FusedAuraFlowAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `FusedAuraFlowAttnProcessor2_0` name.

    Subclasses `FusedAuraFlowAttnProcessorSDPA` without overriding anything except the constructor, which reports
    the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `FusedAuraFlowAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `FusedAuraFlowAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`FusedAuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedAuraFlowAttnProcessorSDPA`"
        deprecate("FusedAuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class FluxAttnProcessor2_0(FluxAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `FluxAttnProcessor2_0` name.

    Subclasses `FluxAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `FluxAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `FluxAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessorSDPA`"
        deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class FusedFluxAttnProcessor2_0(FusedFluxAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `FusedFluxAttnProcessor2_0` name.

    Subclasses `FusedFluxAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `FusedFluxAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `FusedFluxAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedFluxAttnProcessorSDPA`"
        deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class CogVideoXAttnProcessor2_0(CogVideoXAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `CogVideoXAttnProcessor2_0` name.

    Subclasses `CogVideoXAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `CogVideoXAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `CogVideoXAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`CogVideoXAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `CogVideoXAttnProcessorSDPA`"
        deprecate("CogVideoXAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class FusedCogVideoXAttnProcessor2_0(FusedCogVideoXAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `FusedCogVideoXAttnProcessor2_0` name.

    Subclasses `FusedCogVideoXAttnProcessorSDPA` without overriding anything except the constructor, which
    reports the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `FusedCogVideoXAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `FusedCogVideoXAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`FusedCogVideoXAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedCogVideoXAttnProcessorSDPA`"
        deprecate("FusedCogVideoXAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class AttnProcessor2_0(AttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `AttnProcessor2_0` name.

    Subclasses `AttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `AttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `AttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`AttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
        deprecate("AttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class XLAFlashAttnProcessor2_0(XLAFlashAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `XLAFlashAttnProcessor2_0` name.

    Subclasses `XLAFlashAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `XLAFlashAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `XLAFlashAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`XLAFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFlashAttnProcessorSDPA`"
        deprecate("XLAFlashAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class XLAFluxFlashAttnProcessor2_0(XLAFluxFlashAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `XLAFluxFlashAttnProcessor2_0` name.

    Subclasses `XLAFluxFlashAttnProcessorSDPA` without overriding anything except the constructor, which reports
    the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `XLAFluxFlashAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `XLAFluxFlashAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`XLAFluxFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFluxFlashAttnProcessorSDPA`"
        deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class MochiVaeAttnProcessor2_0(MochiVaeAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `MochiVaeAttnProcessor2_0` name.

    Subclasses `MochiVaeAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `MochiVaeAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `MochiVaeAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`MochiVaeAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiVaeAttnProcessorSDPA`"
        deprecate("MochiVaeAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class StableAudioAttnProcessor2_0(StableAudioAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `StableAudioAttnProcessor2_0` name.

    Subclasses `StableAudioAttnProcessorSDPA` without overriding anything except the constructor, which reports
    the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `StableAudioAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `StableAudioAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`StableAudioAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `StableAudioAttnProcessorSDPA`"
        deprecate("StableAudioAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class HunyuanAttnProcessor2_0(HunyuanAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `HunyuanAttnProcessor2_0` name.

    Subclasses `HunyuanAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `HunyuanAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `HunyuanAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`HunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `HunyuanAttnProcessorSDPA`"
        deprecate("HunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class FusedHunyuanAttnProcessor2_0(FusedHunyuanAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `FusedHunyuanAttnProcessor2_0` name.

    Subclasses `FusedHunyuanAttnProcessorSDPA` without overriding anything except the constructor, which reports
    the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `FusedHunyuanAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `FusedHunyuanAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`FusedHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedHunyuanAttnProcessorSDPA`"
        deprecate("FusedHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class PAGHunyuanAttnProcessor2_0(PAGHunyuanAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `PAGHunyuanAttnProcessor2_0` name.

    Subclasses `PAGHunyuanAttnProcessorSDPA` without overriding anything except the constructor, which reports
    the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `PAGHunyuanAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `PAGHunyuanAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`PAGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGHunyuanAttnProcessorSDPA`"
        deprecate("PAGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class PAGCFGHunyuanAttnProcessor2_0(PAGCFGHunyuanAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `PAGCFGHunyuanAttnProcessor2_0` name.

    Subclasses `PAGCFGHunyuanAttnProcessorSDPA` without overriding anything except the constructor, which reports
    the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `PAGCFGHunyuanAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `PAGCFGHunyuanAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`PAGCFGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGHunyuanAttnProcessorSDPA`"
        deprecate("PAGCFGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class LuminaAttnProcessor2_0(LuminaAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `LuminaAttnProcessor2_0` name.

    Subclasses `LuminaAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `LuminaAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `LuminaAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`LuminaAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LuminaAttnProcessorSDPA`"
        deprecate("LuminaAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class FusedAttnProcessor2_0(FusedAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `FusedAttnProcessor2_0` name.

    Subclasses `FusedAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `FusedAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `FusedAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`FusedAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedAttnProcessorSDPA`"
        deprecate("FusedAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class PAGIdentitySelfAttnProcessor2_0(PAGIdentitySelfAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `PAGIdentitySelfAttnProcessor2_0` name.

    Subclasses `PAGIdentitySelfAttnProcessorSDPA` without overriding anything except the constructor, which
    reports the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `PAGIdentitySelfAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `PAGIdentitySelfAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`PAGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySelfAttnProcessorSDPA`"
        deprecate("PAGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class PAGCFGIdentitySelfAttnProcessor2_0(PAGCFGIdentitySelfAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `PAGCFGIdentitySelfAttnProcessor2_0` name.

    Subclasses `PAGCFGIdentitySelfAttnProcessorSDPA` without overriding anything except the constructor, which
    reports the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `PAGCFGIdentitySelfAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `PAGCFGIdentitySelfAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`PAGCFGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGIdentitySelfAttnProcessorSDPA`"
        deprecate("PAGCFGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class SanaMultiscaleAttnProcessor2_0(SanaMultiscaleAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `SanaMultiscaleAttnProcessor2_0` name.

    Subclasses `SanaMultiscaleAttnProcessorSDPA` without overriding anything except the constructor, which
    reports the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `SanaMultiscaleAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `SanaMultiscaleAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`SanaMultiscaleAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaMultiscaleAttnProcessorSDPA`"
        deprecate("SanaMultiscaleAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class LoRAAttnProcessor2_0(LoRAAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `LoRAAttnProcessor2_0` name.

    Subclasses `LoRAAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `LoRAAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `LoRAAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`LoRAAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LoRAAttnProcessorSDPA`"
        deprecate("LoRAAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class SanaLinearAttnProcessor2_0(SanaLinearAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `SanaLinearAttnProcessor2_0` name.

    Subclasses `SanaLinearAttnProcessorSDPA` without overriding anything except the constructor, which reports
    the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `SanaLinearAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `SanaLinearAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`SanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaLinearAttnProcessorSDPA`"
        deprecate("SanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class PAGCFGSanaLinearAttnProcessor2_0(PAGCFGSanaLinearAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `PAGCFGSanaLinearAttnProcessor2_0` name.

    Subclasses `PAGCFGSanaLinearAttnProcessorSDPA` without overriding anything except the constructor, which
    reports the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `PAGCFGSanaLinearAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `PAGCFGSanaLinearAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`PAGCFGSanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGSanaLinearAttnProcessorSDPA`"
        deprecate("PAGCFGSanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class PAGIdentitySanaLinearAttnProcessor2_0(PAGIdentitySanaLinearAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `PAGIdentitySanaLinearAttnProcessor2_0` name.

    Subclasses `PAGIdentitySanaLinearAttnProcessorSDPA` without overriding anything except the constructor, which
    reports the deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `PAGIdentitySanaLinearAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `PAGIdentitySanaLinearAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs):
        # Name the legacy class (not its SDPA replacement) as the deprecated one.
        deprecation_message = "`PAGIdentitySanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySanaLinearAttnProcessorSDPA`"
        deprecate("PAGIdentitySanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
class IPAdapterAttnProcessor2_0(IPAdapterAttnProcessorSDPA):
    """
    Backward-compatibility shim for the legacy `IPAdapterAttnProcessor2_0` name.

    Subclasses `IPAdapterAttnProcessorSDPA` without overriding anything except the constructor, which reports the
    deprecation through `deprecate` before delegating to the parent constructor.

    Args:
        *args: Positional arguments forwarded unchanged to `IPAdapterAttnProcessorSDPA.__init__`.
        **kwargs: Keyword arguments forwarded unchanged to `IPAdapterAttnProcessorSDPA.__init__`.
    """

    def __init__(self, *args, **kwargs) -> None:
        # The message previously claimed `IPAdapterAttnProcessorSDPA` was deprecated in favour of itself;
        # the deprecated name is this legacy class.
        deprecation_message = "`IPAdapterAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `IPAdapterAttnProcessorSDPA`"
        deprecate("IPAdapterAttnProcessor2_0", "1.0.0", deprecation_message)
        super().__init__(*args, **kwargs)
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
AttnAddedKVProcessorSDPA,
XFormersAttnAddedKVProcessor,
)
CROSS_ATTENTION_PROCESSORS = (
AttnProcessor,
AttnProcessor2_0,
AttnProcessorSDPA,
XFormersAttnProcessor,
SlicedAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessorSDPA,
FluxIPAdapterJointAttnProcessorSDPA,
)
AttentionProcessor = Union[
AttnProcessor,
CustomDiffusionAttnProcessor,
AttnAddedKVProcessor,
AttnAddedKVProcessor2_0,
JointAttnProcessor2_0,
PAGJointAttnProcessor2_0,
PAGCFGJointAttnProcessor2_0,
FusedJointAttnProcessor2_0,
AllegroAttnProcessor2_0,
AuraFlowAttnProcessor2_0,
FusedAuraFlowAttnProcessor2_0,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
FusedFluxAttnProcessor2_0_NPU,
CogVideoXAttnProcessor2_0,
FusedCogVideoXAttnProcessor2_0,
AttnAddedKVProcessorSDPA,
JointAttnProcessorSDPA,
PAGJointAttnProcessorSDPA,
PAGCFGJointAttnProcessorSDPA,
FusedJointAttnProcessorSDPA,
AllegroAttnProcessorSDPA,
AuraFlowAttnProcessorSDPA,
FusedAuraFlowAttnProcessorSDPA,
FluxAttnProcessorSDPA,
FluxAttnProcessorSDPA_NPU,
FusedFluxAttnProcessorSDPA,
FusedFluxAttnProcessorSDPA_NPU,
CogVideoXAttnProcessorSDPA,
FusedCogVideoXAttnProcessorSDPA,
XFormersAttnAddedKVProcessor,
XFormersAttnProcessor,
XLAFlashAttnProcessor2_0,
XLAFlashAttnProcessorSDPA,
AttnProcessorNPU,
AttnProcessor2_0,
MochiVaeAttnProcessor2_0,
MochiAttnProcessor2_0,
StableAudioAttnProcessor2_0,
HunyuanAttnProcessor2_0,
FusedHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
FusedAttnProcessor2_0,
AttnProcessorSDPA,
MochiVaeAttnProcessorSDPA,
MochiAttnProcessorSDPA,
StableAudioAttnProcessorSDPA,
HunyuanAttnProcessorSDPA,
FusedHunyuanAttnProcessorSDPA,
PAGHunyuanAttnProcessorSDPA,
PAGCFGHunyuanAttnProcessorSDPA,
LuminaAttnProcessorSDPA,
FusedAttnProcessorSDPA,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionAttnProcessorSDPA,
SlicedAttnProcessor,
SlicedAttnAddedKVProcessor,
SanaLinearAttnProcessor2_0,
PAGCFGSanaLinearAttnProcessor2_0,
PAGIdentitySanaLinearAttnProcessor2_0,
SanaLinearAttnProcessorSDPA,
PAGCFGSanaLinearAttnProcessorSDPA,
PAGIdentitySanaLinearAttnProcessorSDPA,
SanaMultiscaleLinearAttention,
SanaMultiscaleAttnProcessor2_0,
SanaMultiscaleAttnProcessorSDPA,
SanaMultiscaleAttentionProjection,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterAttnProcessorSDPA,
IPAdapterXFormersAttnProcessor,
SD3IPAdapterJointAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
PAGCFGIdentitySelfAttnProcessor2_0,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
SD3IPAdapterJointAttnProcessorSDPA,
PAGIdentitySelfAttnProcessorSDPA,
PAGCFGIdentitySelfAttnProcessorSDPA,
LoRAAttnProcessorSDPA,
LoRAXFormersAttnProcessor,
LoRAAttnAddedKVProcessor,
]