mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix wan 2.1 i2v context parallel (#12909)
* fix wan 2.1 i2v context parallel * fix wan 2.1 i2v context parallel * fix wan 2.1 i2v context parallel * format
This commit is contained in:
@@ -134,7 +134,8 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12909
|
||||
parallel_config=None,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
@@ -147,7 +148,8 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12909
|
||||
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
@@ -552,9 +554,11 @@ class WanTransformer3DModel(
|
||||
"blocks.0": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"blocks.*": {
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12909
|
||||
# We need to disable the splitting of encoder_hidden_states because the image_encoder
|
||||
# (Wan 2.1 I2V) consistently generates 257 tokens for image_embed. This causes the shape
|
||||
# of encoder_hidden_states—whose token count is always 769 (512 + 257) after concatenation
|
||||
# —to be indivisible by the number of devices in the CP.
|
||||
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
"": {
|
||||
"timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
|
||||
@@ -609,7 +609,8 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12909
|
||||
parallel_config=None,
|
||||
)
|
||||
hidden_states_img = hidden_states_img.flatten(2, 3)
|
||||
hidden_states_img = hidden_states_img.type_as(query)
|
||||
@@ -622,7 +623,8 @@ class WanAttnProcessor:
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12909
|
||||
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
Reference in New Issue
Block a user