Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Merge branch 'main' into ltx-2-transformer
.github/workflows/codeql.yml (vendored, new file, 22 lines added)
@@ -0,0 +1,22 @@
+---
+name: CodeQL Security Analysis For Github Actions
+
+on:
+  push:
+    branches: ["main"]
+  workflow_dispatch:
+  # pull_request:
+
+jobs:
+  codeql:
+    name: CodeQL Analysis
+    uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
+    permissions:
+      security-events: write
+      packages: read
+      actions: read
+      contents: read
+    with:
+      languages: '["actions","python"]'
+      queries: 'security-extended,security-and-quality'
+      runner: 'ubuntu-latest' #optional if need custom runner
@@ -1420,6 +1420,7 @@ def _flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
@@ -1427,6 +1428,9 @@ def _flash_attention(
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     lse = None
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 2.")
+
     if _parallel_config is None:
         out = flash_attn_func(
             q=query,
@@ -1469,6 +1473,7 @@ def _flash_attention_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
@@ -1476,6 +1481,9 @@ def _flash_attention_hub(
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
     lse = None
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 2.")
+
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
     out = func(
         q=query,
@@ -1612,11 +1620,15 @@ def _flash_attention_3(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 3.")
+
     out, lse = _wrapped_flash_attn_3(
         q=query,
         k=key,
@@ -1636,6 +1648,7 @@ def _flash_attention_3_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     scale: Optional[float] = None,
     is_causal: bool = False,
     window_size: Tuple[int, int] = (-1, -1),
@@ -1646,6 +1659,8 @@ def _flash_attention_3_hub(
 ) -> torch.Tensor:
     if _parallel_config:
         raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for flash-attn 3.")

     func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
     out = func(
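Each flash-attn 2 and flash-attn 3 wrapper (direct and Hub-kernel variants) gains an `attn_mask` parameter whose only purpose is to reject a mask explicitly, since these kernels do not consume arbitrary masks. A minimal sketch of the resulting behavior; `_require_no_attn_mask` is a hypothetical name used only for illustration and is not part of diffusers:

    from typing import Optional

    import torch


    def _require_no_attn_mask(attn_mask: Optional[torch.Tensor], backend: str) -> None:
        # Hypothetical helper mirroring the inline checks the hunks above add:
        # fail fast with a clear message instead of erroring deep inside the kernel call.
        if attn_mask is not None:
            raise ValueError(f"`attn_mask` is not supported for {backend}.")


    # A caller that passes a mask to a flash backend now sees a readable error:
    mask = torch.ones(1, 1, 16, 16, dtype=torch.bool)
    try:
        _require_no_attn_mask(mask, "flash-attn 2")
    except ValueError as err:
        print(err)  # `attn_mask` is not supported for flash-attn 2.

Accepting the parameter (even just to reject it) keeps every wrapper's signature uniform, so the dispatcher can forward the same keyword set to any backend.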
@@ -1785,12 +1800,16 @@ def _aiter_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for aiter attention")
+
     if not return_lse and torch.is_grad_enabled():
         # aiter requires return_lse=True by assertion when gradients are enabled.
         out, lse, *_ = aiter_flash_attn_func(
@@ -2028,6 +2047,7 @@ def _native_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
     scale: Optional[float] = None,
@@ -2035,6 +2055,9 @@ def _native_flash_attention(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for aiter attention")
+
     lse = None
     if _parallel_config is None and not return_lse:
         query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
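The native flash path above permutes the tensors into the layout `torch.nn.functional.scaled_dot_product_attention` expects before taking its fast route (`_parallel_config is None and not return_lse`). A self-contained sketch of that layout handling, assuming the dispatcher convention of (batch, seq_len, heads, head_dim) tensors:

    import torch
    import torch.nn.functional as F

    # Illustration only (not diffusers code): SDPA wants (batch, heads, seq_len, head_dim),
    # so the wrapper permutes in, runs attention, and permutes back out.
    batch, seq_len, heads, head_dim = 2, 16, 8, 64
    query = torch.randn(batch, seq_len, heads, head_dim)
    key = torch.randn(batch, seq_len, heads, head_dim)
    value = torch.randn(batch, seq_len, heads, head_dim)

    q, k, v = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)
    out = out.permute(0, 2, 1, 3)  # back to (batch, seq_len, heads, head_dim)
    print(out.shape)  # torch.Size([2, 16, 8, 64])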
@@ -2113,11 +2136,14 @@ def _native_npu_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     dropout_p: float = 0.0,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for NPU attention")
     if return_lse:
         raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
     query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
@@ -2148,10 +2174,13 @@ def _native_xla_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for XLA attention")
     if return_lse:
         raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
     query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
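Both the NPU and XLA backends also refuse `return_lse=True`. The log-sum-exp is usually requested so that context-parallel attention can merge partial outputs computed over disjoint key/value chunks; a small illustration of that merge step (not diffusers code), assuming outputs shaped (batch, heads, seq_q, head_dim) and lse shaped (batch, heads, seq_q):

    import torch


    def merge_partial_attention(out_a, lse_a, out_b, lse_b):
        # Each partial output is already normalized over its own key/value chunk;
        # the log-sum-exp values let us reweight and combine them exactly.
        lse = torch.logaddexp(lse_a, lse_b)
        weight_a = torch.exp(lse_a - lse).unsqueeze(-1)
        weight_b = torch.exp(lse_b - lse).unsqueeze(-1)
        return weight_a * out_a + weight_b * out_b, lse


    out_a = torch.randn(1, 8, 32, 64)
    out_b = torch.randn(1, 8, 32, 64)
    lse_a = torch.randn(1, 8, 32)
    lse_b = torch.randn(1, 8, 32)
    merged, merged_lse = merge_partial_attention(out_a, lse_a, out_b, lse_b)
    print(merged.shape, merged_lse.shape)  # torch.Size([1, 8, 32, 64]) torch.Size([1, 8, 32])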
@@ -2175,11 +2204,14 @@ def _sage_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     lse = None
     if _parallel_config is None:
         out = sageattn(
@@ -2223,11 +2255,14 @@ def _sage_attention_hub(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     lse = None
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
     if _parallel_config is None:
@@ -2309,11 +2344,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp8_cuda(
         q=query,
         k=key,
@@ -2333,11 +2371,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp8_cuda_sm90(
         q=query,
         k=key,
@@ -2357,11 +2398,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp16_cuda(
         q=query,
         k=key,
@@ -2381,11 +2425,14 @@ def _sage_qk_int8_pv_fp16_triton_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
     is_causal: bool = False,
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if attn_mask is not None:
+        raise ValueError("`attn_mask` is not supported for sage attention")
     return sageattn_qk_int8_pv_fp16_triton(
         q=query,
         k=key,
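Taken together, the dispatcher-side hunks give every backend wrapper the same keyword surface: all of them now accept `attn_mask`, and the ones that cannot honor a mask raise a `ValueError` up front. Masked attention therefore still needs a mask-capable backend such as torch's native SDPA; a self-contained example of a boolean padding mask with SDPA (illustrative only, not tied to any specific diffusers API):

    import torch
    import torch.nn.functional as F

    # Boolean padding mask: True = attend, False = ignore (here the last 16 keys are padding).
    batch, heads, q_len, kv_len, head_dim = 1, 8, 128, 128, 64
    query = torch.randn(batch, heads, q_len, head_dim)
    key = torch.randn(batch, heads, kv_len, head_dim)
    value = torch.randn(batch, heads, kv_len, head_dim)

    padding_mask = torch.ones(batch, 1, 1, kv_len, dtype=torch.bool)
    padding_mask[..., -16:] = False

    out = F.scaled_dot_product_attention(query, key, value, attn_mask=padding_mask)
    print(out.shape)  # torch.Size([1, 8, 128, 64])

The remaining hunks appear to come from the transformer model files rather than the dispatcher.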
@@ -22,7 +22,7 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,11 +717,7 @@ class FluxTransformer2DModel(
             img_ids = img_ids[0]

         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)

         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
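This hunk drops the Ascend NPU special case: the concatenated ids no longer take a round trip through the CPU, and `self.pos_embed(ids)` is called directly on whatever device the ids live on. For orientation, a simplified single-axis sketch of what a rotary position embedding of concatenated text and image ids produces (the dimensions, theta, and id layout below are illustrative assumptions, not Flux's actual multi-axis configuration):

    import torch


    def rope_cos_sin(ids: torch.Tensor, dim: int = 64, theta: float = 10000.0):
        # ids: (seq_len,) integer positions along one axis; dim must be even.
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=ids.device) / dim))
        angles = ids.to(torch.float32)[:, None] * freqs[None, :]  # (seq_len, dim // 2)
        return torch.cos(angles), torch.sin(angles)


    txt_ids = torch.zeros(77, dtype=torch.long)     # text ids are zeros in Flux-style id layouts
    img_ids = torch.arange(1024, dtype=torch.long)  # flattened latent patch positions (illustrative)
    ids = torch.cat((txt_ids, img_ids), dim=0)
    cos, sin = rope_cos_sin(ids)
    print(cos.shape, sin.shape)  # torch.Size([1101, 32]) torch.Size([1101, 32])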
@@ -21,7 +21,7 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn
@@ -835,14 +835,8 @@ class Flux2Transformer2DModel(
         if txt_ids.ndim == 3:
             txt_ids = txt_ids[0]

-        if is_torch_npu_available():
-            freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
-            image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
-            freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
-            text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
-        else:
-            image_rotary_emb = self.pos_embed(img_ids)
-            text_rotary_emb = self.pos_embed(txt_ids)
+        image_rotary_emb = self.pos_embed(img_ids)
+        text_rotary_emb = self.pos_embed(txt_ids)
         concat_rotary_emb = (
             torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
             torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
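Flux2 gets the same simplification but keeps the text and image rotary embeddings separate, then joins their cos/sin components along the sequence axis, as the surviving `concat_rotary_emb` lines show. A self-contained illustration of that concatenation step (the shapes are stand-ins, not the model's real dimensions):

    import torch

    # Stand-ins for self.pos_embed(txt_ids) and self.pos_embed(img_ids): each returns a (cos, sin) pair.
    text_rotary_emb = (torch.randn(77, 32), torch.randn(77, 32))
    image_rotary_emb = (torch.randn(1024, 32), torch.randn(1024, 32))

    concat_rotary_emb = (
        torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
        torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
    )
    print(concat_rotary_emb[0].shape)  # torch.Size([1101, 32])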
@@ -21,7 +21,7 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import is_torch_npu_available, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -499,11 +499,7 @@ class LongCatImageTransformer2DModel(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)

         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:
@@ -21,7 +21,7 @@ import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import is_torch_npu_available, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -530,11 +530,7 @@ class OvisImageTransformer2DModel(
             img_ids = img_ids[0]

         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)

         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing: