From 3138e37fe62429cdc26ed03097436a2fc7ccb54e Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Tue, 6 Jan 2026 10:12:53 +0800 Subject: [PATCH] Fix wan 2.1 i2v context parallel (#12909) * fix wan 2.1 i2v context parallel * fix wan 2.1 i2v context parallel * fix wan 2.1 i2v context parallel * format --- .../models/transformers/transformer_wan.py | 14 +++++++++----- .../models/transformers/transformer_wan_animate.py | 6 ++++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index f7693ec5d3..132f615f21 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -134,7 +134,8 @@ class WanAttnProcessor: dropout_p=0.0, is_causal=False, backend=self._attention_backend, - parallel_config=self._parallel_config, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + parallel_config=None, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -147,7 +148,8 @@ class WanAttnProcessor: dropout_p=0.0, is_causal=False, backend=self._attention_backend, - parallel_config=self._parallel_config, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + parallel_config=(self._parallel_config if encoder_hidden_states is None else None), ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -552,9 +554,11 @@ class WanTransformer3DModel( "blocks.0": { "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), }, - "blocks.*": { - "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False), - }, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + # We need to disable the splitting of encoder_hidden_states because the image_encoder + # (Wan 2.1 I2V) consistently generates 257 tokens for image_embed. This causes the shape + # of encoder_hidden_states—whose token count is always 769 (512 + 257) after concatenation + # —to be indivisible by the number of devices in the CP. "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3), "": { "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False), diff --git a/src/diffusers/models/transformers/transformer_wan_animate.py b/src/diffusers/models/transformers/transformer_wan_animate.py index 6a47a67385..8860f4bca9 100644 --- a/src/diffusers/models/transformers/transformer_wan_animate.py +++ b/src/diffusers/models/transformers/transformer_wan_animate.py @@ -609,7 +609,8 @@ class WanAttnProcessor: dropout_p=0.0, is_causal=False, backend=self._attention_backend, - parallel_config=self._parallel_config, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + parallel_config=None, ) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) @@ -622,7 +623,8 @@ class WanAttnProcessor: dropout_p=0.0, is_causal=False, backend=self._attention_backend, - parallel_config=self._parallel_config, + # Reference: https://github.com/huggingface/diffusers/pull/12909 + parallel_config=(self._parallel_config if encoder_hidden_states is None else None), ) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query)