From 7f43cb1d7919e69d71516fb088db066f6fb7aa7d Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Thu, 15 Jan 2026 18:10:24 +0800
Subject: [PATCH] fix Qwen-Image series context parallel (#12970)

* fix qwen-image cp

* relax attn_mask limit for cp

* CP plan compatible with zero_cond_t

* move modulate_index plan to top level
---
 src/diffusers/models/attention_dispatch.py                 | 2 --
 src/diffusers/models/transformers/transformer_qwenimage.py | 7 +++++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index f4ec497038..f086c2d425 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -1573,8 +1573,6 @@ def _templated_context_parallel_attention(
     backward_op,
     _parallel_config: Optional["ParallelConfig"] = None,
 ):
-    if attn_mask is not None:
-        raise ValueError("Attention mask is not yet supported for templated attention.")
     if is_causal:
         raise ValueError("Causal attention is not yet supported for templated attention.")
     if enable_gqa:
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index a8c98201d9..cf11d8e01f 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -761,11 +761,14 @@ class QwenImageTransformer2DModel(
     _no_split_modules = ["QwenImageTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
     _repeated_blocks = ["QwenImageTransformerBlock"]
+    # Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702
     _cp_plan = {
-        "": {
+        "transformer_blocks.0": {
             "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
             "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
-            "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+        },
+        "transformer_blocks.*": {
+            "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
         },
         "pos_embed": {
             0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
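
Reviewer note: the sketch below shows roughly how the fixed context-parallel path can be
exercised end to end. It is a minimal sketch, not part of this patch; it assumes the
ContextParallelConfig / enable_parallelism API shape from recent diffusers releases, and the
checkpoint id, ring_degree value, prompt, and torchrun launch are illustrative assumptions.

    # Illustrative sketch only. Assumes a 2-GPU launch, e.g.:
    #   torchrun --nproc-per-node=2 run_qwen_image_cp.py
    # and the ContextParallelConfig / enable_parallelism API from recent
    # diffusers releases; names here are assumptions, not part of this patch.
    import torch
    import torch.distributed as dist
    from diffusers import ContextParallelConfig, QwenImagePipeline

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
    pipe.to(f"cuda:{rank}")

    # Ring attention across 2 ranks. The per-module `_cp_plan` on the transformer
    # (see the hunk above) decides which inputs (hidden_states,
    # encoder_hidden_states, modulate_index, rotary embeddings) are sharded
    # along their sequence dimension.
    pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

    image = pipe("a cat holding a sign that says hello", num_inference_steps=30).images[0]
    if rank == 0:
        image.save("qwen_image_cp.png")

    dist.destroy_process_group()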