1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix wan 2.1 i2v context parallel (#12909)

* fix wan 2.1 i2v context parallel

* fix wan 2.1 i2v context parallel

* fix wan 2.1 i2v context parallel

* format
This commit is contained in:
DefTruth
2026-01-06 10:12:53 +08:00
committed by GitHub
parent 0da1aa90b5
commit 3138e37fe6
2 changed files with 13 additions and 7 deletions

View File

@@ -134,7 +134,8 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
# Reference: https://github.com/huggingface/diffusers/pull/12909
parallel_config=None,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
@@ -147,7 +148,8 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
# Reference: https://github.com/huggingface/diffusers/pull/12909
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
@@ -552,9 +554,11 @@ class WanTransformer3DModel(
"blocks.0": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"blocks.*": {
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
# Reference: https://github.com/huggingface/diffusers/pull/12909
# We need to disable the splitting of encoder_hidden_states because the image_encoder
# (Wan 2.1 I2V) consistently generates 257 tokens for image_embed. This causes the shape
# of encoder_hidden_states—whose token count is always 769 (512 + 257) after concatenation
# —to be indivisible by the number of devices in the CP.
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
"": {
"timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),

View File

@@ -609,7 +609,8 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
# Reference: https://github.com/huggingface/diffusers/pull/12909
parallel_config=None,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
@@ -622,7 +623,8 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
# Reference: https://github.com/huggingface/diffusers/pull/12909
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)