Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)
update
@@ -47,10 +47,19 @@ else:
 class AttentionModuleMixin:
-    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
-        r"""
-        Set whether to use npu flash attention from `torch_npu` or not.
-        """
+    r"""
+    A mixin class that provides common methods for attention modules.
+
+    This mixin adds functionality to set different attention processors, handle attention masks,
+    compute attention scores, and manage projections.
+    """
+
+    def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+        """
+        Set whether to use NPU flash attention from `torch_npu` or not.
+
+        Args:
+            use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
+        """
         if use_npu_flash_attention:
            processor = AttnProcessorNPU()
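The NPU toggle above just swaps the module's processor for `AttnProcessorNPU`. A minimal sketch of how it is typically called (the dimensions are arbitrary; constructing the NPU processor requires the `torch_npu` package and an Ascend device, so this raises elsewhere):

    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    # Installs AttnProcessorNPU; raises an ImportError when `torch_npu`
    # is not installed.
    attn.set_use_npu_flash_attention(True)
    print(type(attn.processor).__name__)  # -> "AttnProcessorNPU"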
@@ -72,14 +81,16 @@ class AttentionModuleMixin:
         partition_spec: Optional[Tuple[Optional[str], ...]] = None,
         is_flux=False,
     ) -> None:
         r"""
-        Set whether to use xla flash attention from `torch_xla` or not.
-        """
+        Set whether to use XLA flash attention from `torch_xla` or not.
+
+        Args:
+            use_xla_flash_attention (`bool`):
+                Whether to use the pallas flash attention kernel from `torch_xla` or not.
+            partition_spec (`Tuple[]`, *optional*):
+                The partition specification to use with SPMD. Otherwise `None`.
+            is_flux (`bool`, *optional*, defaults to `False`):
+                Whether the model is a Flux model.
+        """
         if use_xla_flash_attention:
             if not is_torch_xla_available:
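A hedged sketch of this setter in use (assumes a TPU runtime with `torch_xla` installed, otherwise the call raises; the partition spec shown in the comment is illustrative, not a recommended sharding):

    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    # Raises when torch_xla is unavailable.
    attn.set_use_xla_flash_attention(True)
    # Under SPMD, a partition spec can be forwarded to the pallas kernel:
    # attn.set_use_xla_flash_attention(True, partition_spec=("data", None, None, None))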
@@ -104,7 +115,7 @@ class AttentionModuleMixin:
     def set_use_memory_efficient_attention_xformers(
         self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
     ) -> None:
-        r"""
+        """
         Set whether to use memory efficient attention from `xformers` or not.

         Args:
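A usage sketch for the xformers toggle (assumes a CUDA machine with the `xformers` package installed; `attention_op` pins a specific kernel and can normally be left as `None`):

    import torch
    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=4, dim_head=16).cuda()
    attn.set_use_memory_efficient_attention_xformers(True)
    out = attn(torch.randn(2, 8, 64, device="cuda"))
    print(out.shape)  # torch.Size([2, 8, 64])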
@@ -248,7 +259,7 @@ class AttentionModuleMixin:
         self.set_processor(processor)

     def set_attention_slice(self, slice_size: int) -> None:
-        r"""
+        """
         Set the slice size for attention computation.

         Args:
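Sliced attention trades speed for peak memory by computing attention over the batch-times-heads dimension in chunks. A small CPU-friendly sketch, assuming only the `diffusers` package:

    import torch
    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    attn.set_attention_slice(2)  # installs SlicedAttnProcessor(slice_size=2)
    out = attn(torch.randn(2, 8, 64))
    print(out.shape)  # torch.Size([2, 8, 64])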
@@ -278,7 +289,7 @@ class AttentionModuleMixin:
         self.set_processor(processor)

     def set_processor(self, processor: "AttnProcessor") -> None:
-        r"""
+        """
         Set the attention processor to use.

         Args:
@@ -298,7 +309,7 @@ class AttentionModuleMixin:
         self.processor = processor

     def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
-        r"""
+        """
         Get the attention processor in use.

         Args:
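`set_processor` and `get_processor` are the generic hooks that the toggles above build on. A round-trip sketch (dimensions arbitrary):

    import torch
    from diffusers.models.attention_processor import Attention, AttnProcessor

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    attn.set_processor(AttnProcessor())         # plain PyTorch attention path
    print(type(attn.get_processor()).__name__)  # -> "AttnProcessor"
    out = attn(torch.randn(2, 8, 64))           # the processor runs the forward pass
    print(out.shape)                            # torch.Size([2, 8, 64])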
@@ -312,9 +323,8 @@ class AttentionModuleMixin:
         return self.processor

     def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
-        r"""
-        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
-        is the number of heads initialized while constructing the `Attention` class.
+        """
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.

         Args:
             tensor (`torch.Tensor`): The tensor to reshape.
@@ -329,14 +339,12 @@ class AttentionModuleMixin:
         return tensor

     def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
-        r"""
-        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads`
-        is the number of heads initialized while constructing the `Attention` class.
+        """
+        Reshape the tensor for multi-head attention processing.

         Args:
             tensor (`torch.Tensor`): The tensor to reshape.
-            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
-                reshaped to `[batch_size * heads, seq_len, dim // heads]`.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.

         Returns:
             `torch.Tensor`: The reshaped tensor.
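The two reshape helpers are inverses of each other; `heads` comes from the constructor. A quick shape check:

    import torch
    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    x = torch.randn(2, 8, 64)      # [batch_size, seq_len, dim]
    y = attn.head_to_batch_dim(x)  # [batch_size * heads, seq_len, dim // heads]
    print(y.shape)                 # torch.Size([8, 8, 16])
    z = attn.batch_to_head_dim(y)  # back to [batch_size, seq_len, dim]
    print(z.shape)                 # torch.Size([2, 8, 64])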
@@ -358,13 +366,13 @@ class AttentionModuleMixin:
     def get_attention_scores(
         self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
-        r"""
+        """
         Compute the attention scores.

         Args:
             query (`torch.Tensor`): The query tensor.
             key (`torch.Tensor`): The key tensor.
-            attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
+            attention_mask (`torch.Tensor`, *optional*): The attention mask to use.

         Returns:
             `torch.Tensor`: The attention probabilities/scores.
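`get_attention_scores` expects query and key already in the `[batch_size * heads, seq_len, dim // heads]` layout and returns softmax-normalized probabilities. A sketch:

    import torch
    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    query = attn.head_to_batch_dim(torch.randn(2, 8, 64))
    key = attn.head_to_batch_dim(torch.randn(2, 10, 64))
    probs = attn.get_attention_scores(query, key)  # softmax(Q @ K^T * scale)
    print(probs.shape)                             # torch.Size([8, 8, 10])
    print(probs[0, 0].sum())                       # ~1.0, rows are normalized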
@@ -405,18 +413,14 @@ class AttentionModuleMixin:
     def prepare_attention_mask(
         self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
     ) -> torch.Tensor:
-        r"""
+        """
         Prepare the attention mask for the attention computation.

         Args:
-            attention_mask (`torch.Tensor`):
-                The attention mask to prepare.
-            target_length (`int`):
-                The target length of the attention mask. This is the length of the attention mask after padding.
-            batch_size (`int`):
-                The batch size, which is used to repeat the attention mask.
-            out_dim (`int`, *optional*, defaults to `3`):
-                The output dimension of the attention mask. Can be either `3` or `4`.
+            attention_mask (`torch.Tensor`): The attention mask to prepare.
+            target_length (`int`): The target length of the attention mask after padding.
+            batch_size (`int`): The batch size for repeating the attention mask.
+            out_dim (`int`, *optional*, defaults to `3`): The output dimension. Can be either `3` or `4`.

         Returns:
             `torch.Tensor`: The prepared attention mask.
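The mask is an additive bias: zeros keep positions, large negative values hide them. `prepare_attention_mask` pads it to `target_length` and repeats it once per head. A sketch:

    import torch
    from diffusers.models.attention_processor import Attention

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    mask = torch.zeros(2, 7)  # [batch_size, source_len]
    mask[:, -2:] = -10000.0   # hide the last two key positions
    prepared = attn.prepare_attention_mask(mask, target_length=10, batch_size=2)
    print(prepared.shape)     # torch.Size([8, 10]): padded, one row per head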
@@ -450,9 +454,8 @@ class AttentionModuleMixin:
         return attention_mask

     def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
-        r"""
-        Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
-        `Attention` class.
+        """
+        Normalize the encoder hidden states.

         Args:
             encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
@@ -461,7 +464,6 @@ class AttentionModuleMixin:
             `torch.Tensor`: The normalized encoder hidden states.
         """
-        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

         if isinstance(self.norm_cross, nn.LayerNorm):
             encoder_hidden_states = self.norm_cross(encoder_hidden_states)
         elif isinstance(self.norm_cross, nn.GroupNorm):
@@ -480,6 +482,12 @@ class AttentionModuleMixin:

     @torch.no_grad()
     def fuse_projections(self, fuse=True):
+        """
+        Fuse the query, key, and value projections into a single projection for efficiency.
+
+        Args:
+            fuse (`bool`): Whether to fuse the projections or not.
+        """
         device = self.to_q.weight.data.device
         dtype = self.to_q.weight.data.dtype
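Fusing concatenates the `to_q`/`to_k`/`to_v` weights into one projection so a single GEMM serves all three. A sketch pairing it with the fused SDPA processor (using the name currently exported by `diffusers`; this commit renames it to `FusedAttnProcessorSDPA`):

    import torch
    from diffusers.models.attention_processor import Attention, FusedAttnProcessor2_0

    attn = Attention(query_dim=64, heads=4, dim_head=16)
    attn.fuse_projections(fuse=True)             # creates the fused `to_qkv` linear
    attn.set_processor(FusedAttnProcessor2_0())  # processor that reads `to_qkv`
    out = attn(torch.randn(2, 8, 64))
    print(out.shape)                             # torch.Size([2, 8, 64])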
@@ -4534,7 +4542,7 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
         return hidden_states


-class CustomDiffusionAttnProcessor2_0(nn.Module):
+class CustomDiffusionAttnProcessorSDPA(nn.Module):
     r"""
     Processor for implementing attention for the Custom Diffusion method using PyTorch 2.0's memory-efficient scaled
     dot-product attention.
@@ -5056,13 +5064,6 @@ class IPAdapterAttnProcessor(nn.Module):
         return hidden_states


-class IPAdapterAttnProcessor2_0(IPAdapterAttnProcessorSDPA):
-    def __init__(self, *args, **kwargs) -> None:
-        deprecation_message = "`IPAdapterAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `IPAdapterAttnProcessorSDPA`"
-        deprecate("IPAdapterAttnProcessor2_0", "1.0.0", deprecation_message)
-        super().__init__(*args, **kwargs)
-
-
 class IPAdapterAttnProcessorSDPA(torch.nn.Module):
     r"""
     Attention processor for IP-Adapter for PyTorch 2.0.
@@ -5527,7 +5528,7 @@ class IPAdapterXFormersAttnProcessor(torch.nn.Module):
         return hidden_states


-class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module):
+class SD3IPAdapterJointAttnProcessorSDPA(torch.nn.Module):
     """
     Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with
     additional image-based information and timestep embeddings.
@@ -5996,13 +5997,13 @@ class LoRAAttnAddedKVProcessor:
     pass


-class FluxSingleAttnProcessor2_0(FluxAttnProcessorSDPA):
+class FluxSingleAttnProcessor2_0(FluxSingleAttnProcessorSDPA):
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
     """

     def __init__(self):
-        deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
+        deprecation_message = "`FluxSingleAttnProcessorSDPA` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
         deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
         super().__init__()
@@ -6171,288 +6172,290 @@ class PAGIdentitySanaLinearAttnProcessorSDPA:

 class MochiAttnProcessor2_0(MochiAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`MochiAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiAttnProcessorSDPA`"
-        deprecate("MochiAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`MochiAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiAttnProcessorSDPA`"
+        deprecate("MochiAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)

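All of the shims below follow one pattern: subclass the new name, warn once on construction via the `deprecate` helper, and delegate everything else. A generic sketch of the pattern (the class name here is hypothetical, not part of the library):

    from diffusers.utils import deprecate
    from diffusers.models.attention_processor import AttnProcessor


    class MyOldAttnProcessor(AttnProcessor):
        # Hypothetical alias kept alive while steering users to the new name.
        def __init__(self, *args, **kwargs):
            deprecation_message = (
                "`MyOldAttnProcessor` is deprecated and will be removed in a "
                "future version. Please use `AttnProcessor` instead."
            )
            deprecate("MyOldAttnProcessor", "1.0.0", deprecation_message)
            super().__init__(*args, **kwargs)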
 class AttnAddedKVProcessor2_0(AttnAddedKVProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`AttnAddedKVAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
-        deprecate("AttnAddedKVAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`AttnAddedKVProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnAddedKVProcessorSDPA`"
+        deprecate("AttnAddedKVProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class JointAttnProcessor2_0(JointAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`JointAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `JointAttnProcessorSDPA`"
-        deprecate("JointAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`JointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `JointAttnProcessorSDPA`"
+        deprecate("JointAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class PAGJointAttnProcessor2_0(PAGJointAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`PAGJointAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGJointAttnProcessorSDPA`"
-        deprecate("PAGJointAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`PAGJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGJointAttnProcessorSDPA`"
+        deprecate("PAGJointAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class PAGCFGJointAttnProcessor2_0(PAGCFGJointAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`PAGCFGJointAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGJointAttnProcessorSDPA`"
-        deprecate("PAGCFGJointAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`PAGCFGJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGJointAttnProcessorSDPA`"
+        deprecate("PAGCFGJointAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class FusedJointAttnProcessor2_0(FusedJointAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`FusedJointAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedJointAttnProcessorSDPA`"
-        deprecate("FusedJointAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`FusedJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedJointAttnProcessorSDPA`"
+        deprecate("FusedJointAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class AllegroAttnProcessor2_0(AllegroAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`AllegroAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `AllegroAttnProcessorSDPA`"
-        deprecate("AllegroAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`AllegroAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AllegroAttnProcessorSDPA`"
+        deprecate("AllegroAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class AuraFlowAttnProcessor2_0(AuraFlowAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`AuraFlowAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`"
-        deprecate("AuraFlowAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`AuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AuraFlowAttnProcessorSDPA`"
+        deprecate("AuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class FusedAuraFlowAttnProcessor2_0(FusedAuraFlowAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`FusedAuraFlowAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedAuraFlowAttnProcessorSDPA`"
-        deprecate("FusedAuraFlowAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`FusedAuraFlowAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedAuraFlowAttnProcessorSDPA`"
+        deprecate("FusedAuraFlowAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class FluxAttnProcessor2_0(FluxAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`FluxAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessorSDPA`"
-        deprecate("FluxAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessorSDPA`"
+        deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class FusedFluxAttnProcessor2_0(FusedFluxAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`FusedFluxAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedFluxAttnProcessorSDPA`"
-        deprecate("FusedFluxAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedFluxAttnProcessorSDPA`"
+        deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class CogVideoXAttnProcessor2_0(CogVideoXAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`CogVideoXAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `CogVideoXAttnProcessorSDPA`"
-        deprecate("CogVideoXAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`CogVideoXAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `CogVideoXAttnProcessorSDPA`"
+        deprecate("CogVideoXAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class FusedCogVideoXAttnProcessor2_0(FusedCogVideoXAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`FusedCogVideoXAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedCogVideoXAttnProcessorSDPA`"
-        deprecate("FusedCogVideoXAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`FusedCogVideoXAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedCogVideoXAttnProcessorSDPA`"
+        deprecate("FusedCogVideoXAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class AttnProcessor2_0(AttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`AttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
-        deprecate("AttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`AttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `AttnProcessorSDPA`"
+        deprecate("AttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class XLAFlashAttnProcessor2_0(XLAFlashAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`XLAFlashAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFlashAttnProcessorSDPA`"
-        deprecate("XLAFlashAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`XLAFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFlashAttnProcessorSDPA`"
+        deprecate("XLAFlashAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class XLAFluxFlashAttnProcessor2_0(XLAFluxFlashAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`XLAFluxFlashAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFluxFlashAttnProcessorSDPA`"
-        deprecate("XLAFluxFlashAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`XLAFluxFlashAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `XLAFluxFlashAttnProcessorSDPA`"
+        deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class MochiVaeAttnProcessor2_0(MochiVaeAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`MochiVaeAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiVaeAttnProcessorSDPA`"
-        deprecate("MochiVaeAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`MochiVaeAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `MochiVaeAttnProcessorSDPA`"
+        deprecate("MochiVaeAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class StableAudioAttnProcessor2_0(StableAudioAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`StableAudioAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `StableAudioAttnProcessorSDPA`"
-        deprecate("StableAudioAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`StableAudioAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `StableAudioAttnProcessorSDPA`"
+        deprecate("StableAudioAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class HunyuanAttnProcessor2_0(HunyuanAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`HunyuanAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `HunyuanAttnProcessorSDPA`"
-        deprecate("HunyuanAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`HunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `HunyuanAttnProcessorSDPA`"
+        deprecate("HunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class FusedHunyuanAttnProcessor2_0(FusedHunyuanAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`FusedHunyuanAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedHunyuanAttnProcessorSDPA`"
-        deprecate("FusedHunyuanAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`FusedHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedHunyuanAttnProcessorSDPA`"
+        deprecate("FusedHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class PAGHunyuanAttnProcessor2_0(PAGHunyuanAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`PAGHunyuanAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGHunyuanAttnProcessorSDPA`"
-        deprecate("PAGHunyuanAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`PAGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGHunyuanAttnProcessorSDPA`"
+        deprecate("PAGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class PAGCFGHunyuanAttnProcessor2_0(PAGCFGHunyuanAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`PAGCFGHunyuanAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGHunyuanAttnProcessorSDPA`"
-        deprecate("PAGCFGHunyuanAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`PAGCFGHunyuanAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGHunyuanAttnProcessorSDPA`"
+        deprecate("PAGCFGHunyuanAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class LuminaAttnProcessor2_0(LuminaAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`LuminaAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LuminaAttnProcessorSDPA`"
-        deprecate("LuminaAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`LuminaAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LuminaAttnProcessorSDPA`"
+        deprecate("LuminaAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class FusedAttnProcessor2_0(FusedAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`FusedAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedAttnProcessorSDPA`"
-        deprecate("FusedAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`FusedAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FusedAttnProcessorSDPA`"
+        deprecate("FusedAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class PAGIdentitySelfAttnProcessor2_0(PAGIdentitySelfAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`PAGIdentitySelfAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySelfAttnProcessorSDPA`"
-        deprecate("PAGIdentitySelfAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`PAGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySelfAttnProcessorSDPA`"
+        deprecate("PAGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class PAGCFGIdentitySelfAttnProcessor2_0(PAGCFGIdentitySelfAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`PAGCFGIdentitySelfAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGIdentitySelfAttnProcessorSDPA`"
-        deprecate("PAGCFGIdentitySelfAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`PAGCFGIdentitySelfAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGIdentitySelfAttnProcessorSDPA`"
+        deprecate("PAGCFGIdentitySelfAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class SanaMultiscaleAttnProcessor2_0(SanaMultiscaleAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`SanaMultiscaleAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaMultiscaleAttnProcessorSDPA`"
-        deprecate("SanaMultiscaleAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`SanaMultiscaleAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaMultiscaleAttnProcessorSDPA`"
+        deprecate("SanaMultiscaleAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class LoRAAttnProcessor2_0(LoRAAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`LoRAAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LoRAAttnProcessorSDPA`"
-        deprecate("LoRAAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`LoRAAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `LoRAAttnProcessorSDPA`"
+        deprecate("LoRAAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class SanaLinearAttnProcessor2_0(SanaLinearAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`SanaLinearAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaLinearAttnProcessorSDPA`"
-        deprecate("SanaLinearAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`SanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `SanaLinearAttnProcessorSDPA`"
+        deprecate("SanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class PAGCFGSanaLinearAttnProcessor2_0(PAGCFGSanaLinearAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`PAGCFGSanaLinearAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGSanaLinearAttnProcessorSDPA`"
-        deprecate("PAGCFGSanaLinearAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`PAGCFGSanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGCFGSanaLinearAttnProcessorSDPA`"
+        deprecate("PAGCFGSanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


 class PAGIdentitySanaLinearAttnProcessor2_0(PAGIdentitySanaLinearAttnProcessorSDPA):
     def __init__(self, *args, **kwargs):
-        deprecation_message = "`PAGIdentitySanaLinearAttnAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySanaLinearAttnProcessorSDPA`"
-        deprecate("PAGIdentitySanaLinearAttnAttentionProcessor2_0", "1.0.0", deprecation_message)
+        deprecation_message = "`PAGIdentitySanaLinearAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `PAGIdentitySanaLinearAttnProcessorSDPA`"
+        deprecate("PAGIdentitySanaLinearAttnProcessor2_0", "1.0.0", deprecation_message)
         super().__init__(*args, **kwargs)


+class IPAdapterAttnProcessor2_0(IPAdapterAttnProcessorSDPA):
+    def __init__(self, *args, **kwargs) -> None:
+        deprecation_message = "`IPAdapterAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `IPAdapterAttnProcessorSDPA`"
+        deprecate("IPAdapterAttnProcessor2_0", "1.0.0", deprecation_message)
+        super().__init__(*args, **kwargs)

 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
-    AttnAddedKVProcessor2_0,
+    AttnAddedKVProcessorSDPA,
     XFormersAttnAddedKVProcessor,
 )
 CROSS_ATTENTION_PROCESSORS = (
     AttnProcessor,
-    AttnProcessor2_0,
+    AttnProcessorSDPA,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
     IPAdapterAttnProcessor,
-    IPAdapterAttnProcessor2_0,
-    FluxIPAdapterJointAttnProcessor2_0,
+    IPAdapterAttnProcessorSDPA,
+    FluxIPAdapterJointAttnProcessorSDPA,
 )
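These tuples exist so callers can test a processor against a whole family at once, e.g. when deciding how to reset a model's attention processors. A sketch (using the names currently exported by `diffusers`):

    from diffusers.models.attention_processor import (
        CROSS_ATTENTION_PROCESSORS,
        AttnProcessor,
    )

    proc = AttnProcessor()
    # A tuple of classes is a valid second argument to isinstance().
    print(isinstance(proc, CROSS_ATTENTION_PROCESSORS))  # True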
 AttentionProcessor = Union[
     AttnProcessor,
     CustomDiffusionAttnProcessor,
     AttnAddedKVProcessor,
-    AttnAddedKVProcessor2_0,
-    JointAttnProcessor2_0,
-    PAGJointAttnProcessor2_0,
-    PAGCFGJointAttnProcessor2_0,
-    FusedJointAttnProcessor2_0,
-    AllegroAttnProcessor2_0,
-    AuraFlowAttnProcessor2_0,
-    FusedAuraFlowAttnProcessor2_0,
-    FluxAttnProcessor2_0,
-    FluxAttnProcessor2_0_NPU,
-    FusedFluxAttnProcessor2_0,
-    FusedFluxAttnProcessor2_0_NPU,
-    CogVideoXAttnProcessor2_0,
-    FusedCogVideoXAttnProcessor2_0,
+    AttnAddedKVProcessorSDPA,
+    JointAttnProcessorSDPA,
+    PAGJointAttnProcessorSDPA,
+    PAGCFGJointAttnProcessorSDPA,
+    FusedJointAttnProcessorSDPA,
+    AllegroAttnProcessorSDPA,
+    AuraFlowAttnProcessorSDPA,
+    FusedAuraFlowAttnProcessorSDPA,
+    FluxAttnProcessorSDPA,
+    FluxAttnProcessorSDPA_NPU,
+    FusedFluxAttnProcessorSDPA,
+    FusedFluxAttnProcessorSDPA_NPU,
+    CogVideoXAttnProcessorSDPA,
+    FusedCogVideoXAttnProcessorSDPA,
     XFormersAttnAddedKVProcessor,
     XFormersAttnProcessor,
-    XLAFlashAttnProcessor2_0,
+    XLAFlashAttnProcessorSDPA,
     AttnProcessorNPU,
-    AttnProcessor2_0,
-    MochiVaeAttnProcessor2_0,
-    MochiAttnProcessor2_0,
-    StableAudioAttnProcessor2_0,
-    HunyuanAttnProcessor2_0,
-    FusedHunyuanAttnProcessor2_0,
-    PAGHunyuanAttnProcessor2_0,
-    PAGCFGHunyuanAttnProcessor2_0,
-    LuminaAttnProcessor2_0,
-    FusedAttnProcessor2_0,
+    AttnProcessorSDPA,
+    MochiVaeAttnProcessorSDPA,
+    MochiAttnProcessorSDPA,
+    StableAudioAttnProcessorSDPA,
+    HunyuanAttnProcessorSDPA,
+    FusedHunyuanAttnProcessorSDPA,
+    PAGHunyuanAttnProcessorSDPA,
+    PAGCFGHunyuanAttnProcessorSDPA,
+    LuminaAttnProcessorSDPA,
+    FusedAttnProcessorSDPA,
     CustomDiffusionXFormersAttnProcessor,
-    CustomDiffusionAttnProcessor2_0,
+    CustomDiffusionAttnProcessorSDPA,
     SlicedAttnProcessor,
     SlicedAttnAddedKVProcessor,
-    SanaLinearAttnProcessor2_0,
-    PAGCFGSanaLinearAttnProcessor2_0,
-    PAGIdentitySanaLinearAttnProcessor2_0,
+    SanaLinearAttnProcessorSDPA,
+    PAGCFGSanaLinearAttnProcessorSDPA,
+    PAGIdentitySanaLinearAttnProcessorSDPA,
     SanaMultiscaleLinearAttention,
-    SanaMultiscaleAttnProcessor2_0,
+    SanaMultiscaleAttnProcessorSDPA,
     SanaMultiscaleAttentionProjection,
     IPAdapterAttnProcessor,
-    IPAdapterAttnProcessor2_0,
+    IPAdapterAttnProcessorSDPA,
     IPAdapterXFormersAttnProcessor,
-    SD3IPAdapterJointAttnProcessor2_0,
-    PAGIdentitySelfAttnProcessor2_0,
-    PAGCFGIdentitySelfAttnProcessor2_0,
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
+    SD3IPAdapterJointAttnProcessorSDPA,
+    PAGIdentitySelfAttnProcessorSDPA,
+    PAGCFGIdentitySelfAttnProcessorSDPA,
+    LoRAAttnProcessorSDPA,
     LoRAXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
 ]