
fix Qwen-Image series context parallel (#12970)

* fix qwen-image cp

* relax attn_mask limit for cp

* CP plan compatible with zero_cond_t

* move modulate_index plan to top level
Author: DefTruth
Committed: 2026-01-15 18:10:24 +08:00 (committed by GitHub)
Parent: 5efb81fa71
Commit: 7f43cb1d79
2 changed files with 5 additions and 4 deletions
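For orientation, context parallelism (CP) in diffusers shards the sequence dimension across ranks according to the transformer's _cp_plan, which is what this PR repairs for Qwen-Image. A minimal usage sketch, assuming a 2-GPU torchrun launch and the enable_parallelism/ContextParallelConfig API from recent diffusers releases (the checkpoint id, prompt, and ring_degree are illustrative, not taken from this commit):

# Sketch: run Qwen-Image with ring attention over 2 ranks.
# Launch with: torchrun --nproc-per-node=2 run_qwen_cp.py
import torch
import torch.distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig

dist.init_process_group(backend="nccl")
device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
torch.cuda.set_device(device)

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to(device)

# The transformer's _cp_plan (the attribute patched in this PR) decides
# which forward inputs get split across ranks once parallelism is enabled.
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

image = pipe("a capybara reading a newspaper", num_inference_steps=30).images[0]
if dist.get_rank() == 0:
    image.save("qwen_image_cp.png")
dist.destroy_process_group()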

src/diffusers/models/attention_dispatch.py

@@ -1573,8 +1573,6 @@ def _templated_context_parallel_attention(
     backward_op,
     _parallel_config: Optional["ParallelConfig"] = None,
 ):
-    if attn_mask is not None:
-        raise ValueError("Attention mask is not yet supported for templated attention.")
     if is_causal:
         raise ValueError("Causal attention is not yet supported for templated attention.")
     if enable_gqa:
src/diffusers/models/transformers/transformer_qwenimage.py

@@ -761,11 +761,14 @@ class QwenImageTransformer2DModel(
     _no_split_modules = ["QwenImageTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
     _repeated_blocks = ["QwenImageTransformerBlock"]
+    # Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702
     _cp_plan = {
-        "transformer_blocks.0": {
+        "": {
             "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
             "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
-            "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
         },
+        "transformer_blocks.*": {
+            "modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+        },
         "pos_embed": {
             0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
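
For reference, each ContextParallelInput entry names a forward argument of the module keyed in the plan, the dimension to shard (split_dim), and the tensor rank it expects (expected_dims). A simplified, single-process imitation of the split the CP hooks perform (apply_cp_input is a hypothetical helper for illustration, not the actual diffusers hook code):

import torch

def apply_cp_input(tensor: torch.Tensor, split_dim: int, expected_dims: int,
                   rank: int, world_size: int) -> torch.Tensor:
    # Validate the tensor rank against the plan, then keep this rank's chunk.
    assert tensor.dim() == expected_dims, "plan/tensor rank mismatch"
    return tensor.chunk(world_size, dim=split_dim)[rank]

hidden_states = torch.randn(2, 4096, 3072)                # (batch, seq, channels)
modulate_index = torch.zeros(2, 4096, dtype=torch.long)   # (batch, seq)

# "": hidden_states is split with split_dim=1, expected_dims=3
local_hidden = apply_cp_input(hidden_states, 1, 3, rank=0, world_size=2)
# "transformer_blocks.*": modulate_index is split with split_dim=1, expected_dims=2
local_index = apply_cp_input(modulate_index, 1, 2, rank=0, world_size=2)
print(local_hidden.shape, local_index.shape)              # (2, 2048, 3072) (2, 2048)

The split_output=True form used for pos_embed applies the same chunking to the module's outputs (the rotary frequency tables) instead of its inputs, so the per-rank frequencies line up with the per-rank tokens.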