diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml new file mode 100644 index 0000000000..5ba158b46f --- /dev/null +++ b/.github/workflows/codeql.yml @@ -0,0 +1,22 @@ +--- +name: CodeQL Security Analysis For Github Actions + +on: + push: + branches: ["main"] + workflow_dispatch: + # pull_request: + +jobs: + codeql: + name: CodeQL Analysis + uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1 + permissions: + security-events: write + packages: read + actions: read + contents: read + with: + languages: '["actions","python"]' + queries: 'security-extended,security-and-quality' + runner: 'ubuntu-latest' #optional if need custom runner diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 310c44457c..15516ed2ed 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -1420,6 +1420,7 @@ def _flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -1427,6 +1428,9 @@ def _flash_attention( _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: lse = None + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 2.") + if _parallel_config is None: out = flash_attn_func( q=query, @@ -1469,6 +1473,7 @@ def _flash_attention_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -1476,6 +1481,9 @@ def _flash_attention_hub( _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: lse = None + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 2.") + func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn out = func( q=query, @@ -1612,11 +1620,15 @@ def _flash_attention_3( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, is_causal: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 3.") + out, lse = _wrapped_flash_attn_3( q=query, k=key, @@ -1636,6 +1648,7 @@ def _flash_attention_3_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, scale: Optional[float] = None, is_causal: bool = False, window_size: Tuple[int, int] = (-1, -1), @@ -1646,6 +1659,8 @@ def _flash_attention_3_hub( ) -> torch.Tensor: if _parallel_config: raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.") + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for flash-attn 3.") func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn out = func( @@ -1785,12 +1800,16 @@ def _aiter_flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for aiter attention") + if not return_lse and torch.is_grad_enabled(): # aiter requires return_lse=True by assertion when gradients are enabled. out, lse, *_ = aiter_flash_attn_func( @@ -2028,6 +2047,7 @@ def _native_flash_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, @@ -2035,6 +2055,9 @@ def _native_flash_attention( return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for aiter attention") + lse = None if _parallel_config is None and not return_lse: query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) @@ -2113,11 +2136,14 @@ def _native_npu_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, dropout_p: float = 0.0, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for NPU attention") if return_lse: raise ValueError("NPU attention backend does not support setting `return_lse=True`.") query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value)) @@ -2148,10 +2174,13 @@ def _native_xla_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for XLA attention") if return_lse: raise ValueError("XLA attention backend does not support setting `return_lse=True`.") query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) @@ -2175,11 +2204,14 @@ def _sage_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") lse = None if _parallel_config is None: out = sageattn( @@ -2223,11 +2255,14 @@ def _sage_attention_hub( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") lse = None func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn if _parallel_config is None: @@ -2309,11 +2344,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp8_cuda( q=query, k=key, @@ -2333,11 +2371,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp8_cuda_sm90( q=query, k=key, @@ -2357,11 +2398,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp16_cuda( q=query, k=key, @@ -2381,11 +2425,14 @@ def _sage_qk_int8_pv_fp16_triton_attention( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, is_causal: bool = False, scale: Optional[float] = None, return_lse: bool = False, _parallel_config: Optional["ParallelConfig"] = None, ) -> torch.Tensor: + if attn_mask is not None: + raise ValueError("`attn_mask` is not supported for sage attention") return sageattn_qk_int8_pv_fp16_triton( q=query, k=key, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 16c526f437..1a44644324 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -22,7 +22,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward @@ -717,11 +717,7 @@ class FluxTransformer2DModel( img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) - if is_torch_npu_available(): - freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) - image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) - else: - image_rotary_emb = self.pos_embed(ids) + image_rotary_emb = self.pos_embed(ids) if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py index c10bf3ed4f..8032ec48c1 100644 --- a/src/diffusers/models/transformers/transformer_flux2.py +++ b/src/diffusers/models/transformers/transformer_flux2.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin from ..attention_dispatch import dispatch_attention_fn @@ -835,14 +835,8 @@ class Flux2Transformer2DModel( if txt_ids.ndim == 3: txt_ids = txt_ids[0] - if is_torch_npu_available(): - freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu()) - image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu()) - freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu()) - text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu()) - else: - image_rotary_emb = self.pos_embed(img_ids) - text_rotary_emb = self.pos_embed(txt_ids) + image_rotary_emb = self.pos_embed(img_ids) + text_rotary_emb = self.pos_embed(txt_ids) concat_rotary_emb = ( torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0), torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0), diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py index 7fbaaa3fee..2696f5e787 100644 --- a/src/diffusers/models/transformers/transformer_longcat_image.py +++ b/src/diffusers/models/transformers/transformer_longcat_image.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import is_torch_npu_available, logging +from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -499,11 +499,7 @@ class LongCatImageTransformer2DModel( encoder_hidden_states = self.context_embedder(encoder_hidden_states) ids = torch.cat((txt_ids, img_ids), dim=0) - if is_torch_npu_available(): - freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) - image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) - else: - image_rotary_emb = self.pos_embed(ids) + image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]: diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py index 0a09aa720b..139ceaefa4 100644 --- a/src/diffusers/models/transformers/transformer_ovis_image.py +++ b/src/diffusers/models/transformers/transformer_ovis_image.py @@ -21,7 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import is_torch_npu_available, logging +from ...utils import logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention import AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn @@ -530,11 +530,7 @@ class OvisImageTransformer2DModel( img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) - if is_torch_npu_available(): - freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) - image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) - else: - image_rotary_emb = self.pos_embed(ids) + image_rotary_emb = self.pos_embed(ids) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: