From ade1059ae2bdfd7b2da79c32aec454894544435e Mon Sep 17 00:00:00 2001
From: zhangtao0408 <365968531@qq.com>
Date: Wed, 7 Jan 2026 02:48:04 +0800
Subject: [PATCH] [Flux.1] improve pos embed for ascend npu by computing on npu
 (#12897)

* [Flux.1] improve pos embed for ascend npu by setting it back to npu computation.

* [Flux.2] improve pos embed for ascend npu by setting it back to npu computation.

* [LongCat-Image] improve pos embed for ascend npu by setting it back to npu computation.

* [Ovis-Image] improve pos embed for ascend npu by setting it back to npu computation.

* Remove unused import of is_torch_npu_available

---------

Co-authored-by: zhangtao
---
 .../models/transformers/transformer_flux.py          |  8 ++------
 .../models/transformers/transformer_flux2.py         | 12 +++---------
 .../models/transformers/transformer_longcat_image.py |  8 ++------
 .../models/transformers/transformer_ovis_image.py    |  8 ++------
 4 files changed, 9 insertions(+), 27 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 16c526f437..1a44644324 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -22,7 +22,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,11 +717,7 @@ class FluxTransformer2DModel(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)
 
         if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
             ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
diff --git a/src/diffusers/models/transformers/transformer_flux2.py b/src/diffusers/models/transformers/transformer_flux2.py
index c10bf3ed4f..8032ec48c1 100644
--- a/src/diffusers/models/transformers/transformer_flux2.py
+++ b/src/diffusers/models/transformers/transformer_flux2.py
@@ -21,7 +21,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
 from ..attention import AttentionMixin, AttentionModuleMixin
 from ..attention_dispatch import dispatch_attention_fn
@@ -835,14 +835,8 @@ class Flux2Transformer2DModel(
         if txt_ids.ndim == 3:
             txt_ids = txt_ids[0]
 
-        if is_torch_npu_available():
-            freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
-            image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
-            freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
-            text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
-        else:
-            image_rotary_emb = self.pos_embed(img_ids)
-            text_rotary_emb = self.pos_embed(txt_ids)
+        image_rotary_emb = self.pos_embed(img_ids)
+        text_rotary_emb = self.pos_embed(txt_ids)
         concat_rotary_emb = (
             torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
             torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
diff --git a/src/diffusers/models/transformers/transformer_longcat_image.py b/src/diffusers/models/transformers/transformer_longcat_image.py
index 7fbaaa3fee..2696f5e787 100644
--- a/src/diffusers/models/transformers/transformer_longcat_image.py
+++ b/src/diffusers/models/transformers/transformer_longcat_image.py
@@ -21,7 +21,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import is_torch_npu_available, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -499,11 +499,7 @@ class LongCatImageTransformer2DModel(
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:
diff --git a/src/diffusers/models/transformers/transformer_ovis_image.py b/src/diffusers/models/transformers/transformer_ovis_image.py
index 0a09aa720b..139ceaefa4 100644
--- a/src/diffusers/models/transformers/transformer_ovis_image.py
+++ b/src/diffusers/models/transformers/transformer_ovis_image.py
@@ -21,7 +21,7 @@ import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import is_torch_npu_available, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import AttentionModuleMixin, FeedForward
 from ..attention_dispatch import dispatch_attention_fn
@@ -530,11 +530,7 @@ class OvisImageTransformer2DModel(
             img_ids = img_ids[0]
 
         ids = torch.cat((txt_ids, img_ids), dim=0)
-        if is_torch_npu_available():
-            freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
-            image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
-        else:
-            image_rotary_emb = self.pos_embed(ids)
+        image_rotary_emb = self.pos_embed(ids)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
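
Note (not part of the patch above): every hunk removes the Ascend NPU workaround that moved the position ids to the CPU, computed the rotary cos/sin tables there, and copied them back with .npu(). The sketch below is a minimal, hypothetical stand-in for that kind of rotary-embedding computation; the name rope_tables and the exact frequency layout are illustrative assumptions, not the diffusers pos_embed implementation. It only shows that the tables can be built directly on whatever device the ids tensor already lives on, which is why the CPU round-trip is no longer needed.

# Illustrative sketch only: rotary cos/sin tables computed on ids.device
# (CPU, CUDA, or Ascend NPU alike), avoiding a host round-trip.
import torch

def rope_tables(ids: torch.Tensor, dim: int, theta: float = 10000.0):
    """Return (cos, sin) tables of shape (seq_len, dim) on ids.device."""
    assert dim % 2 == 0, "rotary embedding dimension must be even"
    # Inverse frequencies are created directly on the same device as `ids`.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=ids.device, dtype=torch.float32) / dim))
    angles = torch.outer(ids.to(torch.float32), freqs)   # (seq_len, dim // 2)
    angles = torch.cat([angles, angles], dim=-1)          # (seq_len, dim)
    return angles.cos(), angles.sin()

# Usage: with `ids` already resident on the NPU, no .cpu()/.npu() copies are needed:
#   cos, sin = rope_tables(ids, dim=128)
#   assert cos.device == ids.device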