minify and deprecate npu/xla processors
@@ -2272,235 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0:
         return hidden_states


-class FluxAttnProcessor2_0_NPU:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        deprecation_message = (
-            "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
-            "alternative solution to use NPU Flash Attention will be provided in the future."
-        )
-        deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        if query.dtype in (torch.float16, torch.bfloat16):
-            hidden_states = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                attn.heads,
-                input_layout="BNSD",
-                pse=None,
-                scale=1.0 / math.sqrt(query.shape[-1]),
-                pre_tockens=65536,
-                next_tockens=65536,
-                keep_prob=1.0,
-                sync=False,
-                inner_precise=0,
-            )[0]
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
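Aside from the Ascend-specific kernel call, the class removed above is the standard Flux multi-head attention skeleton: project to Q/K/V, reshape to (batch, heads, seq_len, head_dim), optionally prepend the context projections and apply rotary embeddings, attend, then reshape back for the output projection. Below is a minimal runnable sketch of that skeleton; `flux_style_attention` is a hypothetical helper (not a diffusers API), and stock `F.scaled_dot_product_attention` stands in for `torch_npu.npu_fusion_attention`, which requires Ascend's `torch_npu` package.

```python
import torch
import torch.nn.functional as F


def flux_style_attention(query, key, value, heads):
    # Inputs are (batch, seq_len, inner_dim), i.e. already projected by
    # to_q/to_k/to_v; the real processor also handles context projections.
    batch_size, _, inner_dim = query.shape
    head_dim = inner_dim // heads

    # (batch, seq, inner_dim) -> (batch, heads, seq, head_dim): the "BNSD"
    # layout the NPU fusion kernel is told to expect via input_layout.
    query = query.view(batch_size, -1, heads, head_dim).transpose(1, 2)
    key = key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
    value = value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

    # The removed code dispatches fp16/bf16 inputs to
    # torch_npu.npu_fusion_attention with scale=1/sqrt(head_dim); this
    # fallback computes the same attention on any device.
    out = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

    # Back to (batch, seq, inner_dim) for the output projection.
    return out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)


# Smoke test with made-up sizes.
out = flux_style_attention(
    torch.randn(2, 16, 64), torch.randn(2, 16, 64), torch.randn(2, 16, 64), heads=8
)
assert out.shape == (2, 16, 64)
```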
-
-
-class FusedFluxAttnProcessor2_0_NPU:
-    """Attention processor used typically in processing the SD3-like self-attention projections."""
-
-    def __init__(self):
-        deprecation_message = (
-            "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
-            "alternative solution to use NPU Flash Attention will be provided in the future."
-        )
-        deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
-            )
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        # `context` projections.
-        if encoder_hidden_states is not None:
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            (
-                encoder_hidden_states_query_proj,
-                encoder_hidden_states_key_proj,
-                encoder_hidden_states_value_proj,
-            ) = torch.split(encoder_qkv, split_size, dim=-1)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        if query.dtype in (torch.float16, torch.bfloat16):
-            hidden_states = torch_npu.npu_fusion_attention(
-                query,
-                key,
-                value,
-                attn.heads,
-                input_layout="BNSD",
-                pse=None,
-                scale=1.0 / math.sqrt(query.shape[-1]),
-                pre_tockens=65536,
-                next_tockens=65536,
-                keep_prob=1.0,
-                sync=False,
-                inner_precise=0,
-            )[0]
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
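The fused variant removed above differs from the first class only in how Q, K and V are produced: a single `to_qkv` (and, for the context stream, `to_added_qkv`) linear layer emits all three concatenated along the last axis, trading three matmuls for one, and `torch.split` recovers the individual tensors. A small sketch of that split, with made-up sizes:

```python
import torch

# Hypothetical sizes for illustration.
batch_size, seq_len, inner_dim = 2, 16, 192

# A fused projection emits Q, K and V concatenated on the last axis, so a
# single matmul replaces the three separate to_q/to_k/to_v projections.
qkv = torch.randn(batch_size, seq_len, 3 * inner_dim)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)

assert query.shape == key.shape == value.shape == (batch_size, seq_len, inner_dim)
```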


 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3130,112 +2901,6 @@ class XLAFlashAttnProcessor2_0:
         return hidden_states


-class XLAFluxFlashAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
-    """
-
-    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
-        deprecation_message = (
-            "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
-            "alternative solution to using XLA Flash Attention will be provided in the future."
-        )
-        deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
-
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError(
-                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
-            )
-        if is_torch_xla_version("<", "2.3"):
-            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
-        if is_spmd() and is_torch_xla_version("<", "2.4"):
-            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
-        self.partition_spec = partition_spec
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
-        # `sample` projections.
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
-        if encoder_hidden_states is not None:
-            # `context` projections.
-            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
-            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
-            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
-            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
-                batch_size, -1, attn.heads, head_dim
-            ).transpose(1, 2)
-
-            if attn.norm_added_q is not None:
-                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
-            if attn.norm_added_k is not None:
-                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
-            # attention
-            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
-            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
-            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        query /= math.sqrt(head_dim)
-        hidden_states = flash_attention(query, key, value, causal=False)
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        if encoder_hidden_states is not None:
-            encoder_hidden_states, hidden_states = (
-                hidden_states[:, : encoder_hidden_states.shape[1]],
-                hidden_states[:, encoder_hidden_states.shape[1] :],
-            )
-
-            # linear proj
-            hidden_states = attn.to_out[0](hidden_states)
-            # dropout
-            hidden_states = attn.to_out[1](hidden_states)
-
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
-            return hidden_states, encoder_hidden_states
-        else:
-            return hidden_states
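One detail in the XLA path removed above: the `flash_attention` call receives no scale argument; instead the query is divided by `sqrt(head_dim)` first, which suggests the pallas kernel applies no softmax scaling of its own. Pre-scaling the query is mathematically equivalent to the usual scaled dot-product attention, as this stock-PyTorch check illustrates (it assumes PyTorch >= 2.1 for the `scale` keyword):

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, seq_len, head_dim), matching the processor's layout.
q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))
head_dim = q.shape[-1]

# Default SDPA applies the 1/sqrt(head_dim) softmax scale internally.
ref = F.scaled_dot_product_attention(q, k, v)

# The removed XLA path pre-scales the query and lets the kernel run unscaled.
pre_scaled = F.scaled_dot_product_attention(q / math.sqrt(head_dim), k, v, scale=1.0)

assert torch.allclose(ref, pre_scaled, atol=1e-6)
```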


 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
@@ -5883,6 +5548,67 @@ class FluxIPAdapterJointAttnProcessor2_0:
         return FluxIPAdapterAttnProcessor(*args, **kwargs)


+class FluxAttnProcessor2_0_NPU:
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_native_npu"
+        return processor
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+    def __new__(cls):
+        deprecation_message = (
+            "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+            "alternative solution to use NPU Flash Attention will be provided in the future."
+        )
+        deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        processor = FluxAttnProcessor()
+        processor._attention_backend = "_fused_npu"
+        return processor
+
+
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = (
+            "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
+            "alternative solution to using XLA Flash Attention will be provided in the future."
+        )
+        deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+
+        from .transformers.transformer_flux import FluxAttnProcessor
+
+        if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
+            deprecation_message = (
+                "partition_spec was not used in the processor implementation when it was added. Passing it "
+                "is a no-op and support for it will be removed."
+            )
+            deprecate("partition_spec", "1.0.0", deprecation_message)
+
+        processor = FluxAttnProcessor(*args, **kwargs)
+        processor._attention_backend = "_native_xla"
+        return processor
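The replacement shims added above lean on a Python detail: when `__new__` returns an object that is not an instance of the class being constructed, `__init__` is skipped entirely, so existing call sites transparently receive a `FluxAttnProcessor` with the appropriate `_attention_backend` already set. A self-contained sketch of the pattern, using hypothetical stand-in class names rather than the diffusers ones:

```python
import warnings


class NewProcessor:
    """Stand-in for the consolidated FluxAttnProcessor."""

    def __init__(self):
        self._attention_backend = None


class OldProcessor:
    """Deprecated name kept so existing imports and call sites still work."""

    def __new__(cls, *args, **kwargs):
        warnings.warn("OldProcessor is deprecated; use NewProcessor.", FutureWarning, stacklevel=2)
        # NewProcessor is not a subclass of OldProcessor, so Python skips
        # OldProcessor.__init__ and the caller gets the replacement directly,
        # with the backend the old class used to hard-code.
        processor = NewProcessor()
        processor._attention_backend = "_native_npu"
        return processor


proc = OldProcessor()
assert isinstance(proc, NewProcessor)
assert proc._attention_backend == "_native_npu"
```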


 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,